LASSO regression using thresholded NMF corr genes?

Regress on NMF score

Pseudobulk Expression

# Load the pseudobulk expression object (a Seurat object, per the printed
# summary) and display it.
pseudobulk_vst <- readRDS(
  file = "../BALLbulk_Deconvolution/BALL_89pt_pseudobulk_vst.rds"
)
pseudobulk_vst
An object of class Seurat 
29105 features across 89 samples within 1 assay 
Active assay: RNA (29105 features, 0 variable features)
# Long table of NMF scores keyed by pseudobulk sample name:
# one row per (Patient, NMF component) with columns Patient, NMF, Lineage, NMFscore.
# Join keys are spelled out so the joins cannot silently change if either table
# gains another shared column.
NMF_ptscores <- read_csv('../CompositionAnalysis/BALL_Composition_DevState_NMFscores.csv') %>%
  # attach the ID columns of the conversion table via the TB identifier
  left_join(read_csv('scBALL_IDconversion.csv'), by = 'TB') %>%
  select(ID_Bulk = ID, contains('NMF')) %>%
  pivot_longer(-ID_Bulk, names_to = 'NMF', values_to = 'NMFscore') %>%
  # row names of the Seurat meta.data become the ID_scRNA column
  left_join(
    pseudobulk_vst@meta.data %>% select(ID_Bulk) %>% rownames_to_column('ID_scRNA'),
    by = 'ID_Bulk'
  ) %>%
  # split e.g. 'NMF1_ProB' into lineage label ('ProB') and component ('NMF1');
  # Lineage must be derived before NMF is overwritten
  mutate(Lineage = str_replace(NMF, '.*_', ''),
         NMF = str_replace(NMF, '_.*', '')) %>%
  select(Patient = ID_scRNA, NMF, Lineage, NMFscore)

── Column specification ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
cols(
  TB = col_character(),
  NMF1_ProB = col_double(),
  NMF2_EarlyLymphoid = col_double(),
  NMF3_Erythroid = col_double(),
  NMF4_PreB = col_double(),
  NMF5_MatureB = col_double(),
  NMF6_MyeloidProg = col_double(),
  NMF7_pDC = col_double(),
  NMF8_HSCMPPLMPP = col_double(),
  NMF9_Monocyte = col_double(),
  NMF10_TNK = col_double()
)

── Column specification ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
cols(
  Patient = col_character(),
  TB = col_character(),
  ID = col_character(),
  Sample_ID = col_character()
)
Joining with `by = join_by(TB)`Joining with `by = join_by(ID_Bulk)`
NMF_ptscores
# Gene/component correlations that passed thresholding with q < 0.01,
# ordered by component and, within a component, by increasing q-value.
NMFcorr <- read_csv('NMF_GeneCorr_Thresholding.csv') %>%
  filter(threshold == 'pass' & qvalue < 0.01) %>%
  arrange(NMF, qvalue)

── Column specification ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
cols(
  NMF = col_character(),
  Gene = col_character(),
  pearson = col_double(),
  pvalue = col_double(),
  qvalue = col_double(),
  pos_K1_threshold = col_double(),
  neg_K1_threshold = col_double(),
  threshold = col_character()
)
NMFcorr

Nested Cross Validation by Feature set

Define feature space - load markers

# Unique genes from the NMF-lineage DE results with padj < 0.01 and a positive statistic
LinDE_FDR01_genes <- read_csv('BALL_DEresults_NMF_Lineage.csv') %>%
  filter(padj < 0.01 & stat > 0) %>%
  distinct(Gene) %>%
  pull(Gene)

── Column specification ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
cols(
  Gene = col_character(),
  baseMean = col_double(),
  log2FoldChange = col_double(),
  lfcSE = col_double(),
  stat = col_double(),
  pvalue = col_double(),
  padj = col_double(),
  Lineage = col_character()
)
# Unique genes from the NMF-lineage DE results with padj < 0.05 and a positive statistic
LinDE_FDR05_genes <- read_csv('BALL_DEresults_NMF_Lineage.csv') %>%
  filter(padj < 0.05 & stat > 0) %>%
  distinct(Gene) %>%
  pull(Gene)

── Column specification ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
cols(
  Gene = col_character(),
  baseMean = col_double(),
  log2FoldChange = col_double(),
  lfcSE = col_double(),
  stat = col_double(),
  pvalue = col_double(),
  padj = col_double(),
  Lineage = col_character()
)
# Unique genes from the B-development cell-type DE results with padj < 0.01 and a positive statistic
BDevDE_FDR01_genes <- read_csv('../BDevelopment_Characterization/BDevelopment_CellType_DEresults.csv') %>%
  filter(padj < 0.01 & stat > 0) %>%
  distinct(Gene) %>%
  pull(Gene)

── Column specification ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
cols(
  Gene = col_character(),
  baseMean = col_double(),
  log2FoldChange = col_double(),
  lfcSE = col_double(),
  stat = col_double(),
  pvalue = col_double(),
  padj = col_double(),
  CellType = col_character()
)
# Unique genes from the B-development cell-type DE results with padj < 0.05 and a positive statistic
BDevDE_FDR05_genes <- read_csv('../BDevelopment_Characterization/BDevelopment_CellType_DEresults.csv') %>%
  filter(padj < 0.05 & stat > 0) %>%
  distinct(Gene) %>%
  pull(Gene)

── Column specification ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
cols(
  Gene = col_character(),
  baseMean = col_double(),
  log2FoldChange = col_double(),
  lfcSE = col_double(),
  stat = col_double(),
  pvalue = col_double(),
  padj = col_double(),
  CellType = col_character()
)
# Flag every correlated gene by membership in each of the four DE gene lists
NMFcorr <- NMFcorr %>%
  mutate(
    LinDE_FDR05 = is.element(Gene, LinDE_FDR05_genes),
    LinDE_FDR01 = is.element(Gene, LinDE_FDR01_genes),
    BDevDE_FDR01 = is.element(Gene, BDevDE_FDR01_genes),
    BDevDE_FDR05 = is.element(Gene, BDevDE_FDR05_genes)
  )
NMFcorr
# Score one fitted model on a held-out fold.
#
# model:        fitted object that supports predict(model, x, s = ) and
#               coef(model, s = ); gridsearch_lasso() passes the result of train_LASSO()
# x_val:        expression matrix of the held-out samples; its row names are the patient IDs
# anno_val:     data frame with columns `Patient` and `NMFscore` for the held-out samples
# lambda:       value forwarded as `s` (the caller uses 'lambda.min' and 'lambda.1se')
# feature_name, iteration, foldname: labels copied into the result and used for model_id
#
# Returns a one-row data frame with columns model_id, lambda, model_size, pearson,
# spearman, features, iteration, foldname. `model_size` is the number of predictors
# with a non-zero coefficient; the intercept is not counted.
evaluate_model <- function(model, x_val, anno_val, lambda, 
                           feature_name, iteration, foldname){

  # Predict the score for the held-out samples and attach the observed NMF score
  pred_y <- predict(model, x_val, s = lambda) %>% data.frame()
  colnames(pred_y) <- 'PredScore'
  pred_y <- pred_y %>% rownames_to_column('Patient') %>% 
    left_join(anno_val, by = 'Patient')
  
  # Agreement between predicted and observed scores in the validation set
  pearson <- cor(pred_y$PredScore, pred_y$NMFscore, method = 'pearson')
  spearman <- cor(pred_y$PredScore, pred_y$NMFscore, method = 'spearman')
  
  # Count selected predictors only: coef() also returns an '(Intercept)' row,
  # which would otherwise inflate every model size by one
  coefs <- as.matrix(coef(model, s = lambda))
  n_selected <- sum(coefs[rownames(coefs) != '(Intercept)', 1] != 0)
  
  # Summary Metrics
  summary_metrics <- data.frame(
    'model_id' = paste0(feature_name, '_iter', iteration, '_', foldname),
    'lambda' = lambda,
    'model_size' = n_selected,
    'pearson' = pearson,
    'spearman' = spearman,
    'features' = feature_name,
    'iteration' = iteration,
    'foldname' = foldname
  )
  return(summary_metrics)
}
# Fit one model on the training fold restricted to `features`, then append one
# row of validation metrics per lambda choice ('lambda.min', 'lambda.1se') to
# `summary_metrics` and return the extended data frame.
# Relies on train_LASSO() and evaluate_model() being defined in the session.
gridsearch_lasso <- function(expr_train, expr_val, anno_train, anno_val, features, feature_name,
                             iteration, foldname, summary_metrics){

  # Restrict both expression matrices to the columns (genes) in the feature set
  keep_train <- colnames(expr_train) %in% features
  keep_val <- colnames(expr_val) %in% features
  x_train <- expr_train[, keep_train]
  x_val <- expr_val[, keep_val]

  fit <- train_LASSO(x_train, anno_train)

  # One metrics row per lambda choice, appended in order to the running table
  per_lambda <- lapply(c('lambda.min', 'lambda.1se'), function(lambda_choice){
    evaluate_model(model = fit, x_val = x_val, anno_val = anno_val, lambda = lambda_choice,
                   feature_name = feature_name, iteration = iteration, foldname = foldname)
  })

  Reduce(rbind, per_lambda, summary_metrics)
}
# One repeat of outer k-fold cross-validation over all feature sets.
#
# train_anno:      data frame with a `Patient` column (plus the response used by the models)
# train_expr:      expression matrix whose row names are the patient IDs
# iteration:       repeat index; also used as the random seed
# feature_sets:    named list of gene vectors; names are used as feature-set labels
# summary_metrics: data frame of results so far; new rows are appended to it
# n_folds:         number of outer folds (default 5, the value previously hard-coded)
#
# Returns `summary_metrics` extended by the rows produced by gridsearch_lasso()
# for every fold and feature set.
nestedCV_regression <- function(train_anno, train_expr, iteration, feature_sets, summary_metrics,
                                n_folds = 5){
  # set up random seed and shuffle data.
  # Expression rows are looked up by patient ID below, so shuffling train_expr does
  # not affect the pairing; the call is kept so the random stream (and therefore
  # the fold assignment) is the same as in earlier runs.
  set.seed(iteration)
  train_anno <- train_anno[sample(nrow(train_anno)),]
  train_expr <- train_expr[sample(nrow(train_expr)),]
  
  ## outer cross validation
  folds <- rsample::vfold_cv(train_anno, v = n_folds)
  for(outer_cv in seq_along(folds$splits)){
    # fold ID
    foldname <- folds$id[[outer_cv]]
    # get anno splits
    anno_train <- rsample::analysis(folds$splits[[outer_cv]])
    anno_val <- rsample::assessment(folds$splits[[outer_cv]])
    # get expr splits; drop = FALSE keeps a matrix even for a single-sample fold
    expr_train <- train_expr[anno_train$Patient, , drop = FALSE]
    expr_val <- train_expr[anno_val$Patient, , drop = FALSE]
    
    # Iterate through the feature sets, train the regression models and collect metrics
    for(feature_name in names(feature_sets)){
      # get feature list
      features <- feature_sets[[feature_name]]
      # run gridsearch and get results
      summary_metrics <- gridsearch_lasso(expr_train = expr_train, expr_val = expr_val, anno_train = anno_train, anno_val = anno_val, 
                                  features = features, feature_name = feature_name, iteration = iteration, foldname = foldname,
                                  summary_metrics = summary_metrics)
    }
  }
  return(summary_metrics)
}
library(tidymodels)
library(glmnet)

# Accumulates the CV metrics of every NMF component
output <- data.frame()
# Expression matrix with patients in rows (all patients that have NMF scores)
train_x <- t(data.matrix(pseudobulk_vst@assays$RNA@data[, unique(NMF_ptscores$Patient)]))

for(NMFcomp in c('NMF1', 'NMF2', 'NMF3', 'NMF4', 'NMF5', 'NMF6', 'NMF7', 'NMF8', 'NMF9', 'NMF10')){

  print(paste0('NMF Component: ', NMFcomp))

  # Response for this component
  train_y <- NMF_ptscores %>% filter(NMF == NMFcomp) %>% select(Patient, NMFscore)

  # Candidate gene sets: correlated genes of this component (positively correlated
  # only, or any sign) intersected with the DE-based filters
  comp_corr <- NMFcorr %>% filter(NMF == NMFcomp)
  comp_poscorr <- comp_corr %>% filter(pearson > 0)
  featurespace <- list(
    'PosCorr_LinDE_FDR05' = comp_poscorr %>% filter(LinDE_FDR05) %>% pull(Gene),
    'PosCorr_LinDE_FDR01' = comp_poscorr %>% filter(LinDE_FDR01) %>% pull(Gene),
    'PosCorr_LinDE_BDevDE_FDR05' = comp_poscorr %>% filter(LinDE_FDR05, BDevDE_FDR05) %>% pull(Gene),
    'PosCorr_LinDE_BDevDE_FDR01' = comp_poscorr %>% filter(LinDE_FDR01, BDevDE_FDR01) %>% pull(Gene),
    'AnyCorr_LinDE_FDR05' = comp_corr %>% filter(LinDE_FDR05) %>% pull(Gene),
    'AnyCorr_LinDE_FDR01' = comp_corr %>% filter(LinDE_FDR01) %>% pull(Gene),
    'AnyCorr_LinDE_BDevDE_FDR05' = comp_corr %>% filter(LinDE_FDR05, BDevDE_FDR05) %>% pull(Gene),
    'AnyCorr_LinDE_BDevDE_FDR01' = comp_corr %>% filter(LinDE_FDR01, BDevDE_FDR01) %>% pull(Gene)
  )

  # Ten repeats of the cross-validation, each seeded by its repeat index
  temp_output <- data.frame()
  for(iteration in 1:10){
    print(paste0('iteration ', iteration))
    temp_output <- nestedCV_regression(train_anno = train_y, train_expr = train_x, iteration = iteration,
                                       feature_sets = featurespace, summary_metrics = temp_output)
  }

  ## annotate and add to final output
  output <- bind_rows(output, mutate(temp_output, NMF = NMFcomp))
}

write_csv(output, 'RepNestedCV_results_NMFregression.csv')

5-fold cross-validation with 10 repeats within the pseudobulk data, to estimate the best parameters and get a gestalt of the overall accuracy. After choosing the best combination of parameters, we will test on the bulk RNA-seq dataset.

# Reload the saved CV metrics so the summaries below can be produced without re-running the CV
output <- read_csv('RepNestedCV_results_NMFregression.csv') 

── Column specification ───────────────────────────────────────────────────────────────────────────────────────────────────────
cols(
  model_id = col_character(),
  lambda = col_character(),
  model_size = col_double(),
  pearson = col_double(),
  spearman = col_double(),
  features = col_character(),
  iteration = col_double(),
  foldname = col_character(),
  NMF = col_character()
)
# Number of result rows per feature set
output %>% pull(features) %>% table()
.
AnyCorr_LinDE_BDevDE_FDR01 AnyCorr_LinDE_BDevDE_FDR05        AnyCorr_LinDE_FDR01        AnyCorr_LinDE_FDR05 
                      1000                       1000                       1000                       1000 
PosCorr_LinDE_BDevDE_FDR01 PosCorr_LinDE_BDevDE_FDR05        PosCorr_LinDE_FDR01        PosCorr_LinDE_FDR05 
                      1000                       1000                       1000                       1000 
# Model size by feature set and lambda choice, one panel per NMF component
output %>%
  mutate(NMF = factor(NMF, levels = c('NMF1', 'NMF2', 'NMF3', 'NMF4', 'NMF5', 'NMF6', 'NMF7', 'NMF8', 'NMF9', 'NMF10'))) %>%
  ggplot(aes(x = reorder(features, -model_size), y = model_size, fill = lambda)) +
  # dashed reference lines at 10, 20, 30 and 40
  geom_hline(yintercept = c(10, 20, 30, 40), lty = 2) +
  geom_boxplot(outlier.size = 0.8) +
  ggbeeswarm::geom_quasirandom(dodge.width = 0.7, size = 0.2, alpha = 0.7) +
  facet_wrap(.~NMF, ncol = 6) +
  theme_pubr() +
  theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5)) +
  stat_compare_means(label = 'p.signif')

NA
# Validation Pearson correlation by feature set and lambda choice, one panel per NMF component
output %>%
  mutate(NMF = factor(NMF, levels = c('NMF1', 'NMF2', 'NMF3', 'NMF4', 'NMF5', 'NMF6', 'NMF7', 'NMF8', 'NMF9', 'NMF10'))) %>%
  ggplot(aes(x = reorder(features, -pearson), y = pearson, fill = lambda)) +
  geom_boxplot(outlier.size = 0.8) +
  ggbeeswarm::geom_quasirandom(dodge.width = 0.7, size = 0.2, alpha = 0.7) +
  facet_wrap(.~NMF, ncol = 6) +
  theme_pubr() +
  theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5)) +
  stat_compare_means(label = 'p.signif')

NA

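In summary, using a combination of positively and negatively correlated genes leads to better performance, particularly for: NMF1 (Pro-B), NMF5 (Erythroid) and NMF11 (Naive T). (Note: these labels do not match the 10-component naming used in the rest of this notebook, where NMF3 is Erythroid, NMF5 is Mature B and there is no NMF11; they appear to come from an earlier NMF run and should be re-checked.)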

In terms of filtering the correlated genes, requiring both lineage DE and B-development DE at FDR < 0.01 (the LinDE_BDevDE_FDR01 feature sets) consistently led to the best performance.

Choosing lambda.1se (the "minimum + 1 SE" rule) instead of lambda.min resulted in model size reductions of nearly 20 genes without sacrificing performance in nested CV.

# Median model size and validation correlations per NMF component for the chosen
# configuration, ordered from worst to best Pearson correlation.
# Alternative feature set considered: AnyCorr_LinDE_FDR01
output_CVmedians <- output %>%
  filter(lambda == 'lambda.1se', features == 'AnyCorr_LinDE_BDevDE_FDR01') %>%
  select(NMF, model_size, pearson, spearman) %>%
  group_by(NMF) %>%
  # across() replaces the superseded summarise_all()
  summarise(across(everything(), median)) %>%
  arrange(pearson)

output_CVmedians 

Final Choice: Corr Threshold + Lineage DE FDR01 + B-development DE FDR01; lambda.1se (min + 1SE)

Train from both positively corr and any corr. If models from both approaches have negative coefficients, use any corr. Also make sure that the genes are present in the bulk RNAseq data

# Fit a glmnet regression of the NMF score on expression, selecting lambda by
# leave-one-out cross-validation.
#
# x_train: numeric matrix with samples in rows and genes in columns
# y_train: data frame whose `NMFscore` column is the response; its rows must be
#          in the same order as the rows of `x_train` (matching is positional)
# alpha:   mixing parameter passed to cv.glmnet (default 1)
#
# Returns the fitted cv.glmnet object.
train_LASSO <- function(x_train, y_train, alpha = 1){
  
  train_y <- y_train$NMFscore
  
  # Perform Lasso regression with LOOCV: one fold per sample.
  # `nfolds` is spelled out in full (the previous `nfold` only worked through
  # partial argument matching). With one observation per fold cv.glmnet enforces
  # grouped = FALSE anyway, so it is set explicitly to avoid a warning on every fit.
  # Predictors are not standardized.
  model <- cv.glmnet(x = x_train, y = train_y, nfolds = nrow(x_train), grouped = FALSE,
                     family = 'gaussian', alpha = alpha, maxit = 1000000, standardize = FALSE)
  #plot(model)

  return(model)
}
# Gene identifiers present in the bulk RNA-seq count table.
# Only the Gene column is needed here, so only that column is read
# instead of the whole count matrix.
bulkRNAgenes <- data.table::fread(
  "../BALLbulk_Deconvolution/BALL_bulkRNA_data/BALL_BulkRNAseq_subsetgene_rawcounts.txt",
  select = 'Gene'
)$Gene
bulkRNAgenes %>% length()
[1] 36110
set.seed(123)
# One fitted model per NMF component, keyed by component name
model_list <- list()
# Only genes that are also present in the bulk RNA data are eligible as features
NMFcorr <- NMFcorr %>% filter(Gene %in% bulkRNAgenes)

for(NMFcomp in c('NMF1', 'NMF2', 'NMF3', 'NMF4', 'NMF5', 'NMF6', 'NMF7', 'NMF8', 'NMF9', 'NMF10')){

  # Response: NMF score of this component for every patient
  train_y <- NMF_ptscores %>%
    filter(NMF == NMFcomp) %>%
    select(Patient, NMFscore)

  # Features: correlated genes of this component passing both DE filters at FDR < 0.01
  feature_space <- NMFcorr %>%
    filter(NMF == NMFcomp, LinDE_FDR01, BDevDE_FDR01) %>%
    pull(Gene)

  # Expression of those genes, patients in rows, in the same order as train_y
  train_x <- t(data.matrix(pseudobulk_vst@assays$RNA@data[feature_space, train_y$Patient]))

  model <- train_LASSO(train_x, y_train = train_y)
  model_list[[NMFcomp]] <- model
}
Warning: Option grouped=FALSE enforced in cv.glmnet, since < 3 observations per foldWarning: Option grouped=FALSE enforced in cv.glmnet, since < 3 observations per foldWarning: Option grouped=FALSE enforced in cv.glmnet, since < 3 observations per foldWarning: Option grouped=FALSE enforced in cv.glmnet, since < 3 observations per foldWarning: Option grouped=FALSE enforced in cv.glmnet, since < 3 observations per foldWarning: Option grouped=FALSE enforced in cv.glmnet, since < 3 observations per foldWarning: Option grouped=FALSE enforced in cv.glmnet, since < 3 observations per foldWarning: Option grouped=FALSE enforced in cv.glmnet, since < 3 observations per foldWarning: Option grouped=FALSE enforced in cv.glmnet, since < 3 observations per foldWarning: Option grouped=FALSE enforced in cv.glmnet, since < 3 observations per fold
# model weights
# Long table (Model, Gene, Weight) of the non-zero coefficients of every
# component model at lambda.1se, sorted by decreasing weight within each model.
# (This chunk used to appear twice verbatim; it only needs to run once.)
modelweights <- lapply(
  c('NMF1', 'NMF2', 'NMF3', 'NMF4', 'NMF5', 'NMF6', 'NMF7', 'NMF8', 'NMF9', 'NMF10'),
  function(NMFcomp){
    coefs <- as.matrix(coef(model_list[[NMFcomp]], s = 'lambda.1se'))
    # Take the single coefficient column by position so the code does not depend
    # on how coef() happens to name it
    data.frame(Model = NMFcomp, Gene = rownames(coefs), Weight = unname(coefs[, 1])) %>%
      filter(Gene != '(Intercept)', Weight != 0) %>%
      arrange(-Weight)
  }
) %>% bind_rows()

modelweights <- modelweights %>% select(Model, Gene, Weight)
# number of selected genes per model
modelweights %>% group_by(Model) %>% summarise(count = n())
# Display order of the named lineage signatures
NMFnamed_levels <- c('HSC_MPP', 'Myeloid_Prog', 'Pre_pDC', 'Early_Lymphoid', 'Pro_B', 'Pre_B', 
                      'Mature_B', 'Erythroid', 'Monocyte', 'T_NK')

# Lookup from NMF component ID to named signature. It is defined before the plot
# below because the plot joins against it.
NMFconvert <- data.frame(
  'NMF' = c('NMF8', 'NMF6', 'NMF7', 'NMF2', 'NMF1', 'NMF4', 
            'NMF5', 'NMF3', 'NMF9', 'NMF10') %>% factor(),
  'NMFnamed' = NMFnamed_levels %>% factor(levels = NMFnamed_levels)
)

# Number of positive and negative coefficients per signature
modelweights %>% 
  left_join(NMFconvert %>% dplyr::rename(Model = NMF), by = 'Model') %>%
  mutate(coefficient = ifelse(Weight > 0, 'Positive', 'Negative') %>% factor(levels = c('Positive', 'Negative'))) %>% 
  count(NMFnamed, coefficient, name = 'count') %>%
  ggplot(aes(x = NMFnamed, y = count, fill = coefficient)) + geom_col() + ggpubr::theme_pubr() + 
  ggsci::scale_fill_simpsons() + theme(axis.text.x = element_text(angle = 90, hjust = 1))

NMFconvert
# Reshape the long weights into one row per gene and one column per named
# signature; genes not selected by a signature get weight 0.
modelweights <- modelweights %>%
  left_join(NMFconvert %>% dplyr::rename(Model = NMF)) %>%
  arrange(NMFnamed, Gene) %>%
  pivot_wider(id_cols = Gene, names_from = NMFnamed, values_from = Weight, values_fill = 0)
Joining with `by = join_by(Model)`
# Persist the gene x signature weight table
write_csv(modelweights, "NMF_Lasso_ModelWeights.csv")
modelweights

Figure out NMF scoring and validate on pseudobulk

# Score samples with the signature weights: for every sample, the weighted sum
# of its expression over the model genes, one column per signature.
#
# query:        expression matrix with genes as row names and samples as columns
# modelweights: data frame with a `Gene` column and one numeric weight column per signature
# scale:        if TRUE, centre and scale each signature column across the samples of `query`
# sampleID:     name of the output column that holds the sample names
#
# Model genes absent from `query` are dropped with a warning.
# Returns a data frame: `sampleID` column followed by one column per signature.
calculate_NMFscores <- function(query, modelweights, scale = TRUE, sampleID = 'Patient'){

  modelweights <- as.data.frame(modelweights)

  # Check for overlap with model genes and query genes
  querygenes <- rownames(query)
  present <- modelweights$Gene %in% querygenes
  if(any(!present)){
    warning(sum(!present), ' genes from NMF models are missing from query dataset', call. = FALSE)
  }

  # filter model weights and turn them into a gene x signature matrix
  modelweights <- modelweights[present, , drop = FALSE]
  weight_cols <- setdiff(colnames(modelweights), 'Gene')
  modelweights_mat <- data.matrix(modelweights[, weight_cols, drop = FALSE])
  rownames(modelweights_mat) <- modelweights$Gene

  # multiply query by NMF lasso weights; drop = FALSE keeps the subset a matrix
  # even when only one model gene is present
  scored <- as.matrix(t(query[modelweights$Gene, , drop = FALSE]) %*% modelweights_mat)
  if(isTRUE(scale)){
    # base::scale is named explicitly because the `scale` argument shadows it
    scored <- base::scale(scored)
  }

  # sample names move from row names into the first column
  scored <- as.data.frame(scored)
  id_col <- data.frame(rownames(scored), stringsAsFactors = FALSE)
  colnames(id_col) <- sampleID
  scored <- cbind(id_col, scored)
  rownames(scored) <- NULL

  return(scored)
}
# Reload the saved gene x signature weight table
modelweights <- read_csv("NMF_Lasso_ModelWeights.csv")

── Column specification ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
cols(
  Gene = col_character(),
  HSC_MPP = col_double(),
  Myeloid_Prog = col_double(),
  Pre_pDC = col_double(),
  Early_Lymphoid = col_double(),
  Pro_B = col_double(),
  Pre_B = col_double(),
  Mature_B = col_double(),
  Erythroid = col_double(),
  Monocyte = col_double(),
  T_NK = col_double()
)
modelweights
pseudobulk_vst <- readRDS("../BALLbulk_Deconvolution/BALL_89pt_pseudobulk_vst.rds")
# calculate NMF scores from vst-normalized data
# (TRUE is spelled out: `T` is an ordinary variable that can be reassigned)
pseudobulk_vst_NMFscores <- calculate_NMFscores(pseudobulk_vst@assays$RNA@data, modelweights, scale = TRUE, sampleID = 'Patient')
pseudobulk_vst_NMFscores[1:10,1:10]

Training results - pseudobulk

# Observed NMF scores next to the scores predicted from the pseudobulk data,
# one row per patient and signature
NMF_compare <- left_join(
  left_join(NMF_ptscores, NMFconvert),
  pivot_longer(pseudobulk_vst_NMFscores, -Patient, names_to = 'NMFnamed', values_to = 'predNMF')
)
Joining with `by = join_by(NMF)`Joining with `by = join_by(Patient, NMFnamed)`
# Predicted vs. observed score per signature, with linear fit and correlation
NMF_compare %>%
  mutate(NMFnamed = factor(NMFnamed, levels = NMFnamed_levels)) %>%
  ggplot(aes(x = predNMF, y = NMFscore)) +
  geom_point() +
  geom_smooth(method = 'lm') +
  facet_wrap(.~NMFnamed, scales = 'free') +
  theme_pubr() +
  stat_cor()

True Bulk RNAseq “validation”

# Raw bulk RNA-seq counts as a numeric matrix with genes as row names
bulkRNA_counts <- data.matrix(
  column_to_rownames(
    data.table::fread("../BALLbulk_Deconvolution/BALL_bulkRNA_data/BALL_BulkRNAseq_subsetgene_rawcounts.txt"),
    'Gene'
  )
)
bulkRNA_counts[1:10,1:10]
          SJBALL020608_D1 SJBALL082_D SJINF074_D SJBALL209_D SJINF049_D SJALL050848_D1 SJBALL016239_D1 SJBALL205_D SJBALL016303_D1 SJBALL016244_D1
A1BG                  143         317        182         122         82            252            1090         298              78              51
A1BG-AS1              147         771        279         326        199            307             921         552             119             111
A1CF                    1           2          1           0          0              0               0           0               0               1
A2M                    11           8         20          75         15              4               4           8               1              88
A2M-AS1                 6         198          5         489         26             43              15          38              23              36
A2ML1                   0           0          2           0          1              2               0           0               0               0
A2ML1-AS1               0           1          0           0          0              0               0           0               0               0
A2ML1-AS2               0           0          0           0          0              0               1           0               0               0
A3GALT2                 2           1          2           0          2              2               0           0               5               0
A4GALT                 59          32         60          60         34              4              37          18               2              44
library(DESeq2)
# Variance-stabilising transformation of the bulk counts.
# Genes with fewer than 10 reads in total are removed first; the design has an
# intercept only and colData carries nothing but the sample names as row names.
bulkRNA_dds <- DESeqDataSetFromMatrix(
  countData = bulkRNA_counts[rowSums(bulkRNA_counts) >= 10,],
  colData = data.frame(row.names = colnames(bulkRNA_counts)),
  design = ~1
)
bulkRNA_vst <- assay(vst(bulkRNA_dds))
# the raw counts and the DESeq object are no longer needed
rm(bulkRNA_counts, bulkRNA_dds)

bulkRNA_vst[1:10,1:10]
          SJBALL020608_D1 SJBALL082_D SJINF074_D SJBALL209_D SJINF049_D SJALL050848_D1 SJBALL016239_D1 SJBALL205_D
A1BG             7.140496    7.459732   7.069288    6.857087   6.789549       8.334186        9.654538    8.049401
A1BG-AS1         7.171814    8.579933   7.566172    8.013636   7.813683       8.592772        9.422367    8.855233
A1CF             4.393532    4.420817   4.350597    4.088563   4.088563       4.088563        4.088563    4.088563
A2M              5.082162    4.748766   5.231875    6.364534   5.376806       4.830286        4.673378    4.948642
A2M-AS1          4.828826    6.923254   4.671326    8.537353   5.748445       6.318600        5.201054    5.874671
A2ML1            4.088563    4.088563   4.458630    4.088563   4.431539       4.615895        4.088563    4.088563
A2ML1-AS1        4.088563    4.323761   4.088563    4.088563   4.088563       4.088563        4.088563    4.088563
A2ML1-AS2        4.088563    4.088563   4.088563    4.088563   4.088563       4.088563        4.382468    4.088563
A3GALT2          4.519059    4.323761   4.458630    4.088563   4.572475       4.615895        4.088563    4.088563
A4GALT           6.232613    5.377247   5.982178    6.159561   5.959306       4.830286        5.780731    5.356654
          SJBALL016303_D1 SJBALL016244_D1
A1BG             7.804691        7.006957
A1BG-AS1         8.341231        7.928454
A1CF             4.088563        4.568334
A2M              4.632465        7.641534
A2M-AS1          6.442545        6.636279
A2ML1            4.088563        4.088563
A2ML1-AS1        4.088563        4.088563
A2ML1-AS2        4.088563        4.088563
A3GALT2          5.278005        4.088563
A4GALT           4.853334        6.846260
# calculate NMF scores from vst-normalized data
# (TRUE is spelled out: `T` is an ordinary variable that can be reassigned)
bulkRNA_vst_NMFscores <- calculate_NMFscores(bulkRNA_vst, modelweights, scale = TRUE, sampleID = 'Patient')
bulkRNA_vst_NMFscores[1:10,1:10]

Validation results - matched bulkRNAseq

# Attach the Sample, ID and TB identifiers from the metadata table to every score row
NMF_ptscores_converted <- left_join(
  NMF_ptscores,
  read_delim("../BALL_metadata_20230105.txt", delim = '\t') %>%
    select(Patient = Directory, Sample, ID, TB)
)

── Column specification ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
cols(
  Sample = col_character(),
  ID = col_character(),
  Directory = col_character(),
  TB = col_character(),
  age = col_double(),
  sex = col_character(),
  diagnosis = col_character(),
  subdiagnosis = col_character(),
  location = col_character()
)
Joining with `by = join_by(Patient)`
# Observed NMF scores next to the scores predicted from bulk RNA-seq, matched
# on the `ID` column and the signature name
NMF_ptscores_compareBulkRNA <- left_join(
  left_join(NMF_ptscores_converted, NMFconvert),
  bulkRNA_vst_NMFscores %>%
    pivot_longer(-Patient, names_to = 'NMFnamed', values_to = 'predNMF') %>%
    dplyr::rename(ID = Patient)
)
Joining with `by = join_by(NMF)`Joining with `by = join_by(ID, NMFnamed)`
NMF_ptscores_compareBulkRNA
# Predicted vs. observed score per signature. The two figures differ only in
# the facet layout, so everything else is built once.
bulk_validation_scatter <- NMF_ptscores_compareBulkRNA %>% 
  mutate(NMFnamed = factor(NMFnamed, NMFnamed_levels)) %>% 
  ggplot(aes(x = predNMF, y = NMFscore)) + 
  geom_point() + geom_smooth(method = 'lm') + 
  theme_pubr() + stat_cor() + 
  ylab('NMF Lineage Score (scRNA composition)') + xlab('Predicted NMF Score (matched bulk RNA-seq)')

# wide layout: 5 panels per row
bulk_validation_scatter + facet_wrap(.~NMFnamed, scales = 'free', ncol = 5)

# tall layout: 2 panels per row
bulk_validation_scatter + facet_wrap(.~NMFnamed, scales = 'free', ncol = 2)

Not bad!! Save bulk RNAseq NMF scores:

# Persist the bulk RNA-seq lineage scores
write_csv(bulkRNA_vst_NMFscores, 'Bulk2046_NMFregression_LineageScores.csv')
bulkRNA_vst_NMFscores
# CV Pearson correlation of the chosen configuration, one box per named signature
output %>%
  left_join(NMFconvert) %>%
  #mutate(NMF = factor(NMF, levels = c('NMF1', 'NMF2', 'NMF3', 'NMF4', 'NMF5', 'NMF6', 'NMF7', 'NMF8', 'NMF9', 'NMF10'))) %>% 
  filter(features == 'AnyCorr_LinDE_BDevDE_FDR01' & lambda == 'lambda.1se') %>%
  ggplot(aes(x = NMFnamed, y = pearson, fill = NMFnamed)) +
  theme_pubr(legend = 'none') +
  # dashed reference lines at r = 0.5, 0.75 and 0.9
  geom_hline(yintercept = c(0.5, 0.75, 0.9), lty = 2) +
  geom_boxplot(outlier.size = 0) +
  ggbeeswarm::geom_quasirandom(size = 1, alpha = 0.7) +
  theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5))
Joining with `by = join_by(NMF)`

NA
# Same plot with the signatures reordered by reorder(NMFnamed, -pearson)
output %>%
  left_join(NMFconvert) %>%
  #mutate(NMF = factor(NMF, levels = c('NMF1', 'NMF2', 'NMF3', 'NMF4', 'NMF5', 'NMF6', 'NMF7', 'NMF8', 'NMF9', 'NMF10'))) %>% 
  filter(features == 'AnyCorr_LinDE_BDevDE_FDR01' & lambda == 'lambda.1se') %>%
  ggplot(aes(x = reorder(NMFnamed, -pearson), y = pearson, fill = NMFnamed)) +
  theme_pubr(legend = 'none') +
  # dashed reference lines at r = 0.5, 0.75 and 0.9
  geom_hline(yintercept = c(0.5, 0.75, 0.9), lty = 2) +
  geom_boxplot(outlier.size = 0) +
  ggbeeswarm::geom_quasirandom(size = 1, alpha = 0.7) +
  theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5))

# rows of the bulk comparison table without any missing value
drop_na(NMF_ptscores_compareBulkRNA)

Quickly evaluate on sorted cord blood fractions

# Expression of the sorted cord-blood fractions as a numeric matrix with genes as row names
cb_fractions <- data.table::fread('../../../../../../../Dormancy/HSC_analysis/HSC_byOntogeny/Dicklab_sortedFractions_Batch4_CB_Hierarchy_vst.csv') %>%
  column_to_rownames('Gene') %>%
  data.matrix()
dim(cb_fractions)
[1] 43165   107
# Score the cord-blood fractions (TRUE spelled out: `T` can be reassigned)
CBfractions_vst_NMFscores <- calculate_NMFscores(cb_fractions, modelweights, scale = TRUE, sampleID = 'Sample')
[1] "Warning: 1 genes from NMF models are missing from query dataset"
CBfractions_vst_NMFscores
# Signature score per sorted population, one panel per signature
CBfractions_vst_NMFscores %>% 
  pivot_longer(-Sample, names_to = 'NMFsig', values_to = 'Score') %>% 
  mutate(NMFsig = factor(NMFsig, levels = NMFnamed_levels), 
         # population label = sample name with everything up to 'CB_' removed
         Population = Sample %>% str_replace('.*CB_',''),
         # labels outside this list become NA
         Population = factor(Population, levels = c('HSC', 'MPP', 'LMPP', 'CMP', 'GMP', 'MLPII', 'EarlyProB', 'PreProB', 
                                                    'ProB', 'PreB', 'B', #'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 
                                                    'T', 'NK', 'EryP', 'Mono', 'Gr'))) %>% 
  # drop populations that are not in the list above; testing with is.na() states
  # the intent directly instead of relying on how `!= 'NA'` treats missing values
  filter(!is.na(Population)) %>% 
  ggplot(aes(x = Population, y = Score, fill = Population)) + 
  geom_boxplot() + geom_jitter() +
  theme_pubr(legend = 'none') + theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5)) +
  facet_wrap(.~NMFsig, scales = 'free', ncol = 3)

NA
# NOTE(review): this chunk used to end in a dangling `%>%`, which does not parse.
# The trailing pipe is removed so the chunk runs and shows the long-format table;
# whatever step was meant to follow is not recoverable from this file.
CBfractions_vst_NMFscores %>% 
  pivot_longer(-Sample, names_to = 'NMFsig', values_to = 'Score') %>% 
  mutate(NMFsig = factor(NMFsig, levels = NMFnamed_levels), 
         Population = Sample %>% str_replace('.*CB_',''),
         Population = factor(Population, levels = c('HSC', 'MPP', 'LMPP', 'CMP', 'GMP', 'MLPII', 'EarlyProB', 'PreProB', 
                                                    'ProB', 'PreB', 'B', #'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 
                                                    'T', 'NK', 'EryP', 'Mono', 'Gr')))

Evaluate on Pharmacotypes

# Convert an RPKM/FPKM table to log1p(TPM).
#
# dat: data frame (or data.table / tibble) with a gene symbol column named
#      "Gene" and one numeric RPKM column per sample.
#
# Within each sample, TPM = RPKM / sum(RPKM) * 1e6, and the result is
# log1p(TPM) (natural log). Missing RPKM values are ignored when computing a
# sample's total and stay NA in the output.
# Returns a data frame with the `Gene` column followed by the sample columns,
# with rows and columns in the same order as in `dat`.
rpkm_to_logTPM <- function(dat){
  dat <- as.data.frame(dat)
  stopifnot('Gene' %in% colnames(dat))

  # samples x genes arithmetic is done on a matrix instead of reshaping to long format
  sample_cols <- setdiff(colnames(dat), 'Gene')
  rpkm <- as.matrix(dat[, sample_cols, drop = FALSE])

  # convert to TPM: divide every column by its total and scale to one million
  tpm <- sweep(rpkm, 2, colSums(rpkm, na.rm = TRUE), '/') * 1000000

  # check.names = FALSE keeps sample names exactly as they were
  dat_TPM <- data.frame(Gene = dat$Gene, log1p(tpm), check.names = FALSE)

  return(dat_TPM)
}
# FPKM table: drop the GeneID column and use the gene name as `Gene`
pharmacotype_fpkm <- data.table::fread('pharmacotypes/pharmacotyping_ped_rnaseq_fpkm_ALLids_0823.csv') %>%
  select(-GeneID) %>%
  dplyr::rename(Gene = GeneName)
pharmacotype_fpkm
# log1p(TPM), then a numeric matrix with genes as row names
pharmacotype_logTPM <- rpkm_to_logTPM(pharmacotype_fpkm)
pharmacotype_logTPM <- data.matrix(column_to_rownames(pharmacotype_logTPM, 'Gene'))
dim(pharmacotype_logTPM)
# in-vitro pharmacotype table
pharmacotypes <- read_csv('pharmacotypes/ALL_invitro_pharmacotypes.csv')

── Column specification ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
cols(
  .default = col_double(),
  `Patient ID` = col_character(),
  `Sample ID` = col_character(),
  Immunophenotype = col_character(),
  `Molecular subtype` = col_character(),
  Protocol = col_character(),
  `NCI risk` = col_character(),
  Sex = col_character(),
  `Population and ancestry` = col_character()
)
ℹ Use `spec()` for the full column specifications.
pharmacotypes
# NOTE(review): `pharmacotype_logTPM_scored` was not created anywhere in the
# visible notebook. The join below matches on `Patient ID` and later code reads
# the signature columns (HSC_MPP, ...), so it is presumably the output of
# calculate_NMFscores() on the logTPM matrix with sampleID = 'Patient ID'.
# TODO confirm this call, in particular the scaling choice, against the original analysis.
pharmacotype_logTPM_scored <- calculate_NMFscores(pharmacotype_logTPM, modelweights, scale = TRUE, sampleID = 'Patient ID')
# keep B-lineage samples that have both pharmacotype and expression-based scores
pharmacotypes_combined <- pharmacotypes %>% inner_join(pharmacotype_logTPM_scored) %>% filter(Immunophenotype == 'B')
Joining with `by = join_by(`Patient ID`)`
pharmacotypes_combined
# Spearman correlation between every '*normalized*' drug column and every
# lineage score, using pairwise-complete observations
CellType_Drug_Corr <- cor(
  x = pharmacotypes_combined %>% select(contains('normalized')),
  y = pharmacotypes_combined %>%
    select(HSC_MPP, Myeloid_Prog, Pre_pDC, Early_Lymphoid,
           Pro_B, Pre_B, Mature_B, T_NK, Monocyte, Erythroid),
  use = 'pairwise.complete.obs',
  method = 'spearman'
)
CellType_Drug_Corr
                              HSC_MPP Myeloid_Prog      Pre_pDC Early_Lymphoid        Pro_B        Pre_B     Mature_B        T_NK     Monocyte   Erythroid
Asparaginase_normalized    0.06297156  0.146949854 -0.011408241    0.029215760 -0.176086803  0.043973855 -0.028931055 -0.02380685  0.074471547 -0.09803874
Bortezomib_normalized      0.10425820 -0.156163247 -0.009683717   -0.099246543 -0.032780758  0.015565485  0.091815461  0.05255612 -0.164868343  0.01398804
CHZ868_normalized         -0.23378972  0.143426764 -0.365216511   -0.128265675 -0.088951192 -0.322730680 -0.135304022  0.24603521  0.271579908  0.18831668
Cytarabine_normalized      0.06574424  0.037662561  0.025593659    0.157748568 -0.001227337 -0.176277524 -0.068655134 -0.08831712 -0.092632208 -0.03560725
Dasatinib_normalized       0.08106764  0.021310534  0.071770981    0.104709385  0.023748006 -0.170917984 -0.001895330 -0.05798105 -0.090192980  0.09435230
Daunorubicin_normalized    0.11223808  0.073212138  0.093254818   -0.055860016 -0.118538604 -0.127861052  0.094619932  0.01121027 -0.009758818 -0.05619371
Dexamethasone_normalized   0.01083389  0.004212854  0.099537893   -0.070950751 -0.153948693 -0.182048537  0.134455023  0.13229506  0.101970249  0.05366476
Ibrutinib_normalized       0.07157877  0.027072835 -0.130261628    0.005968476 -0.154157001 -0.021891568  0.118207788  0.14282596  0.031242657  0.08075572
Mercaptopurine_normalized  0.11074588  0.111776761  0.032717251    0.054946846 -0.152427880 -0.130912251  0.030001120 -0.02326335  0.004634170  0.01969312
Nelarabine_normalized      0.03055041 -0.187116251  0.161731174    0.122817203  0.071404728 -0.122710150  0.004737120 -0.01666021 -0.102182628  0.19340565
Panobinostat_normalized   -0.06571385  0.115933478 -0.046417098   -0.028606394 -0.146441137 -0.336305426  0.022880745  0.05598899  0.152997225  0.20260495
Prednisolone_normalized    0.20934634  0.059269053  0.250990414    0.123834369 -0.092154364 -0.036719634 -0.029667418 -0.06413731 -0.040582133 -0.04349867
Ruxolitinib_normalized     0.11996152  0.010825715  0.012268038   -0.096602513  0.022206807  0.083861990 -0.014207715 -0.07556449 -0.080703799  0.02473502
Thioguanine_normalized     0.27711897  0.066801038  0.179078618    0.218979572 -0.003125489  0.005995744 -0.116854086 -0.14908658 -0.185575536 -0.08553664
Trametinib_normalized     -0.24754694 -0.137116487 -0.336145912   -0.033905251  0.063186014 -0.096718034 -0.028086933  0.16363125  0.029472491  0.25015189
Venetoclax_normalized     -0.18051023  0.057766556 -0.409731802   -0.026677120 -0.063144189 -0.055577985  0.001113826  0.20009403  0.239062327  0.07513834
Vincristine_normalized     0.04590763  0.040889885  0.102042820    0.097484450  0.097688176  0.006290842 -0.058296676 -0.07866391 -0.134378552 -0.06877864
Vorinostat_normalized     -0.04506071 -0.040089676 -0.087037205    0.085285161 -0.034263351 -0.272137053  0.025057343  0.05707103  0.027493618  0.10202289
CellType_Drug_Corr
library(corrplot)
CellType_Drug_Corr %>% t() %>% corrplot()

Jae Kim B-ALL Ph+

Kim_PhALL_samples <- list.files('subtype_subcluster/Kim2023_Ph_BALL/RNAseq_rawcounts/')
Kim_PhALL_samples
 [1] "JAMLR_0003_nn_P_count_sub.txt" "JAMLR_0004_nn_P_count_sub.txt" "JAMLR_0005_nn_P_count_sub.txt" "JAMLR_0006_nn_P_count_sub.txt"
 [5] "JAMLR_0007_nn_P_count_sub.txt" "JAMLR_0008_nn_P_count_sub.txt" "JAMLR_0009_nn_P_count_sub.txt" "JAMLR_0010_nn_P_count_sub.txt"
 [9] "JAMLR_0011_nn_P_count_sub.txt" "JAMLR_0012_nn_P_count_sub.txt" "JAMLR_0013_nn_P_count_sub.txt" "JAMLR_0014_nn_M_count_sub.txt"
[13] "JAMLR_0014_nn_P_count_sub.txt" "JAMLR_0015_nn_P_count_sub.txt" "JAMLR_0016_nn_P_count_sub.txt" "JAMLR_0017_nn_M_count_sub.txt"
[17] "JAMLR_0017_nn_P_count_sub.txt" "JAMLR_0018_nn_P_count_sub.txt" "JAMLR_0019_nn_M_count_sub.txt" "JAMLR_0019_nn_P_count_sub.txt"
[21] "JAMLR_0020_nn_P_count_sub.txt" "JAMLR_0021_nn_P_count_sub.txt" "JAMLR_0022_nn_P_count_sub.txt" "JAMLR_0023_nn_P_count_sub.txt"
[25] "JAMLR_0024_nn_P_count_sub.txt" "JAMLR_0025_nn_P_count_sub.txt" "JAMLR_0026_nn_P_count_sub.txt" "JAMLR_0027_nn_P_count_sub.txt"
[29] "JAMLR_0028_nn_P_count_sub.txt" "JAMLR_0029_nn_P_count_sub.txt" "JAMLR_0030_nn_P_count_sub.txt" "JAMLR_0031_nn_P_count_sub.txt"
[33] "JAMLR_0032_nn_P_count_sub.txt" "JAMLR_0033_nn_P_count_sub.txt" "JAMLR_0034_nn_P_count_sub.txt" "JAMLR_0035_nn_P_count_sub.txt"
[37] "JAMLR_0036_nn_P_count_sub.txt" "JAMLR_0037_nn_P_count_sub.txt" "JAMLR_0038_nn_P_count_sub.txt" "JAMLR_0039_nn_P_count_sub.txt"
[41] "JAMLR_0040_nn_P_count_sub.txt" "JAMLR_0041_nn_P_count_sub.txt" "JAMLR_0042_nn_P_count_sub.txt" "JAMLR_0043_nn_P_count_sub.txt"
[45] "JAMLR_0044_nn_P_count_sub.txt" "JAMLR_0045_nn_P_count_sub.txt" "JAMLR_0046_nn_P_count_sub.txt" "JAMLR_0047_nn_P_count_sub.txt"
[49] "JAMLR_0048_nn_P_count_sub.txt" "JAMLR_0049_nn_P_count_sub.txt" "JAMLR_0050_nn_P_count_sub.txt" "JAMLR_0051_nn_P_count_sub.txt"
[53] "JAMLR_0052_nn_P_count_sub.txt" "JAMLR_0053_nn_P_count_sub.txt" "JAMLR_0054_nn_P_count_sub.txt" "JAMLR_0055_nn_M_count_sub.txt"
[57] "JAMLR_0055_nn_P_count_sub.txt"
Kim_PhALL_anno <- read_csv('subtype_subcluster/Kim2023_Ph_BALL/Ph_BALL_ClinicalAnno.csv')

── Column specification ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
cols(
  .default = col_character(),
  Cohort = col_double(),
  Age_at_dx = col_double(),
  Age_at_dx_rounded = col_double(),
  OS = col_double(),
  alive = col_double(),
  RFS = col_double(),
  relapse_or_death = col_double(),
  OS_BMT_censored = col_double(),
  alive_BMT_censored = col_double(),
  WBC = col_double()
)
ℹ Use `spec()` for the full column specifications.
Kim_PhALL_anno
# Build one wide count table (one row per ENSG ID, one column per sample) from the
# per-sample count files. The first file seeds the table; its two columns are
# renamed to 'ENSG' and the file name with '_count_sub.txt' removed.
kim_phALL_counts <- data.table::fread(paste0('subtype_subcluster/Kim2023_Ph_BALL/RNAseq_rawcounts/', Kim_PhALL_samples[1])) 
colnames(kim_phALL_counts) <- c('ENSG', Kim_PhALL_samples[1] %>% str_replace('_count_sub.txt',''))

# Remaining files are joined on one at a time. left_join() is called without `by`,
# so it joins on the shared 'ENSG' column and keeps only the rows already present
# in the accumulated table (i.e. the IDs of the first file).
for(ph_samp in Kim_PhALL_samples[-1]){
  print(ph_samp)
  # Read this sample's two-column count file
  temp <- data.table::fread(paste0('subtype_subcluster/Kim2023_Ph_BALL/RNAseq_rawcounts/', ph_samp))
  colnames(temp) <- c('ENSG', ph_samp %>% str_replace('_count_sub.txt','')) 
  # Append this sample as a new column
  kim_phALL_counts <- kim_phALL_counts %>% left_join(temp)
}
[1] "JAMLR_0004_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0005_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0006_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0007_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0008_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0009_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0010_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0011_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0012_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0013_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0014_nn_M_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0014_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0015_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0016_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0017_nn_M_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0017_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0018_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0019_nn_M_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0019_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0020_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0021_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0022_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0023_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0024_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0025_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0026_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0027_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0028_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0029_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0030_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0031_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0032_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0033_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0034_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0035_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0036_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0037_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0038_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0039_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0040_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0041_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0042_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0043_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0044_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0045_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0046_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0047_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0048_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0049_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0050_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0051_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0052_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0053_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0054_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
[1] "JAMLR_0055_nn_M_count_sub.txt"
Warning: Stopped early on line 57774. Expected 2 fields but found 3. Consider fill=TRUE and comment.char=. First discarded non-empty line: <<__no_feature       733049>>Joining with `by = join_by(ENSG)`
[1] "JAMLR_0055_nn_P_count_sub.txt"
Joining with `by = join_by(ENSG)`
kim_phALL_counts
ENSGconvert <- data.table::fread('../../../../../../../CIBERSORT/newDataSets_Jul2020/preprocessing/GRCh38_transcript_lengths.txt')
ENSGconvert <- ENSGconvert %>% select(ENSG = V1, Gene = V2) %>% unique()
ENSGconvert
# Collapse ENSG-level counts to gene symbols. The inner join (on the shared 'ENSG'
# column) drops IDs that have no entry in ENSGconvert; counts of IDs that map to
# the same symbol are summed per sample. summarise() places the grouping column
# `Gene` first, so no explicit column reordering is needed.
kim_phALL_counts <- kim_phALL_counts %>%
  inner_join(ENSGconvert) %>%
  select(-ENSG) %>%
  group_by(Gene) %>%
  summarise(across(everything(), sum))
Joining with `by = join_by(ENSG)`
kim_phALL_counts
kim_phALL_counts  %>% write_csv('subtype_subcluster/Kim2023_Ph_BALL/Kim2023_Ph_BALL_RNAseq_counts.csv')
library(DESeq2)

kim_phALL_vst <- DESeqDataSetFromMatrix(kim_phALL_counts %>% column_to_rownames('Gene') %>% data.matrix() , 
                       colData = data.frame('Patient' = colnames(kim_phALL_counts)[-1]) %>% column_to_rownames('Patient'), 
                       design = ~1) %>% vst() %>% assay()
rm(kim_phALL_counts)

kim_phALL_vst[1:10,1:10]
          JAMLR_0003_nn_P JAMLR_0004_nn_P JAMLR_0005_nn_P JAMLR_0006_nn_P JAMLR_0007_nn_P JAMLR_0008_nn_P JAMLR_0009_nn_P JAMLR_0010_nn_P JAMLR_0011_nn_P
5S_rRNA          6.403128        6.403128        6.403128        6.403128        6.403128        6.403128        6.403128        6.403128        6.403128
7SK             13.180453       12.333339       11.877076       10.784501       11.035531       10.744278       11.414756       11.357774       10.528394
A1BG             6.541529        6.533009        6.610118        6.664031        6.614425        6.403128        6.528962        6.548396        6.606195
A1BG-AS1         7.733125        8.601058        8.214180        9.511347        7.858343        7.275169        8.648863        8.115911        8.337477
A1CF             6.403128        6.403128        6.403128        6.403128        6.403128        6.403128        6.403128        6.403128        6.403128
A2M             10.222182       10.538301        6.403128        7.063431       11.429050       10.965017        7.799619        8.368576        9.622616
A2M-AS1          7.400566        7.084282        6.886712        6.814816        7.249977        6.976037        7.262451        6.812803        6.403128
A2ML1            6.403128        6.403128        6.403128        6.403128        6.403128        6.403128        6.403128        6.403128        6.403128
A2ML1-AS1        6.403128        6.403128        6.403128        6.403128        6.403128        6.403128        6.403128        6.403128        6.403128
A2ML1-AS2        6.403128        6.403128        6.403128        6.403128        6.403128        6.403128        6.403128        6.403128        6.403128
          JAMLR_0012_nn_P
5S_rRNA          6.403128
7SK             11.256461
A1BG             6.707218
A1BG-AS1         8.305869
A1CF             6.403128
A2M              7.226245
A2M-AS1          6.927905
A2ML1            6.403128
A2ML1-AS1        6.403128
A2ML1-AS2        6.403128
kim_phALL_vst %>% as.data.frame() %>% rownames_to_column('Gene') %>% write_csv('subtype_subcluster/Kim2023_Ph_BALL/Kim2023_Ph_BALL_RNAseq_vst.csv')
# calculate NMF scores from vst-normalized data
kim_phALL_vst_NMFscores <- calculate_NMFscores(kim_phALL_vst, modelweights, scale = T, sampleID = 'Sample')
[1] "Warning: 7 genes from NMF models are missing from query dataset"
kim_phALL_vst_NMFscores[1:10,1:10]
kim_phALL_vst_NMFscores %>% write_csv('subtype_subcluster/Kim2023_Ph_BALL/Kim2023_Ph_BALL_LineageNMF_Scored.csv')

compare

# Derive the RNA-seq column name for each annotation row once and reuse it:
# rows whose Manuscript_name contains '-R' map to the '<JAMLR>_nn_M' library,
# all other rows to '<JAMLR>_nn_P'.
Kim_PhALL_anno_cleaned <- Kim_PhALL_anno %>%
  mutate(Sample = ifelse(Manuscript_name %>% str_detect('-R'), paste0(JAMLR, '_nn_M'), paste0(JAMLR, '_nn_P')))

# Save the cleaned annotation
Kim_PhALL_anno_cleaned %>%
  write_csv('subtype_subcluster/Kim2023_Ph_BALL/Kim2023_Ph_BALL_anno_cleaned.csv')

# Attach the lineage NMF scores; the join uses the shared 'Sample' column and
# keeps only samples present in both tables.
Kim_PhALL_anno_LineageScored <- Kim_PhALL_anno_cleaned %>%
  select(Sample, WBC, Age_at_dx, subtype, Subgroup) %>%
  inner_join(kim_phALL_vst_NMFscores)
Joining with `by = join_by(Sample)`
Kim_PhALL_anno_LineageScored
Kim_PhALL_anno %>% pull(Age_at_dx) %>% summary()
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max.    NA's 
  13.00   41.25   51.00   51.23   62.75   88.00       6 
Kim_PhALL_anno_LineageScored %>% 
  select(-Sample, -WBC, -Age_at_dx, -subtype) %>% 
  pivot_longer(-Subgroup) %>% 
  ggplot(aes(x = Subgroup, y = value, fill = Subgroup)) + 
  geom_boxplot() + ggbeeswarm::geom_quasirandom() +
  theme_pubr(legend = 'none') + theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5)) +
  facet_wrap(.~name, scales = 'free', ncol = 5) + stat_compare_means()

---
title: "LASSO Regression"
output: html_notebook
---

LASSO regression using thresholded NMF corr genes? 

## Regress on NMF score

Pseudobulk Expression

```{r}
pseudobulk_vst <- readRDS("../BALLbulk_Deconvolution/BALL_89pt_pseudobulk_vst.rds")
pseudobulk_vst
```


```{r}
# Long table of per-patient NMF scores keyed by the pseudobulk sample name.
#  1. read the wide NMF score table and attach the ID conversion table
#     (joined on their shared column(s), since no `by` is given);
#  2. keep the bulk ID (renamed ID_Bulk) plus every column containing 'NMF';
#  3. pivot to one row per (ID_Bulk, NMF column);
#  4. attach the pseudobulk sample name (rownames of the Seurat meta.data) via ID_Bulk;
#  5. split column names such as 'NMF1_ProB' into NMF = text before the first '_'
#     and Lineage = text after the last '_'.
# left_join() keeps score rows without a matching pseudobulk sample, so Patient
# can be NA for those rows.
NMF_ptscores <- read_csv('../CompositionAnalysis/BALL_Composition_DevState_NMFscores.csv') %>% 
  left_join(read_csv('scBALL_IDconversion.csv')) %>%
  select(ID_Bulk = ID, contains('NMF')) %>%
  pivot_longer(-ID_Bulk, names_to = 'NMF', values_to = 'NMFscore') %>% 
  left_join(pseudobulk_vst@meta.data %>% select(ID_Bulk) %>% rownames_to_column('ID_scRNA') ) %>% 
  mutate(Lineage = NMF %>% str_replace('.*_',''), NMF = NMF %>% str_replace('_.*','')) %>% 
  select(Patient = ID_scRNA, NMF, Lineage, NMFscore) 

NMF_ptscores
```

```{r}
NMFcorr <- read_csv('NMF_GeneCorr_Thresholding.csv') %>% 
  filter(threshold == 'pass', qvalue < 0.01) %>% arrange(qvalue) %>% arrange(NMF)
NMFcorr
```

### Nested Cross Validation by Feature set

Define feature space - load markers 


```{r}
# Candidate gene sets: genes with a positive test statistic (stat > 0) in each
# DE result table, at adjusted p-value cut-offs of 0.01 and 0.05.
# Each result file is read once and reused for both cut-offs.
LinDE_results <- read_csv('BALL_DEresults_NMF_Lineage.csv')
BDevDE_results <- read_csv('../BDevelopment_Characterization/BDevelopment_CellType_DEresults.csv')

LinDE_FDR01_genes <- LinDE_results %>% filter(padj < 0.01, stat > 0) %>% pull(Gene) %>% unique()
LinDE_FDR05_genes <- LinDE_results %>% filter(padj < 0.05, stat > 0) %>% pull(Gene) %>% unique()
BDevDE_FDR01_genes <- BDevDE_results %>% filter(padj < 0.01, stat > 0) %>% pull(Gene) %>% unique()
BDevDE_FDR05_genes <- BDevDE_results %>% filter(padj < 0.05, stat > 0) %>% pull(Gene) %>% unique()
```

```{r}
NMFcorr <- NMFcorr %>% mutate(LinDE_FDR05 = Gene %in% LinDE_FDR05_genes, 
                              LinDE_FDR01 = Gene %in% LinDE_FDR01_genes, 
                              BDevDE_FDR01 = Gene %in% BDevDE_FDR01_genes, 
                              BDevDE_FDR05 = Gene %in% BDevDE_FDR05_genes)
NMFcorr
```

```{r}
evaluate_model <- function(model, x_val, anno_val, lambda, 
                           feature_name, iteration, foldname){
  # Score a fitted model on a held-out fold and return a one-row summary.
  #
  # model        : fitted model accepted by predict()/coef() with an `s` argument
  #                (a cv.glmnet fit when produced by train_LASSO)
  # x_val        : validation expression matrix; rownames are used as Patient IDs
  # anno_val     : validation annotation with columns Patient and NMFscore
  # lambda       : value passed as `s` (e.g. 'lambda.min' or 'lambda.1se')
  # feature_name, iteration, foldname : labels copied into the output row
  #
  # Returns a one-row data.frame: model_id, lambda, model_size, pearson, spearman,
  # features, iteration, foldname.

  # Predict on the validation samples and attach the observed NMF score
  pred_y <- predict(model, x_val, s = lambda) %>% data.frame()
  colnames(pred_y) <- 'PredScore'
  pred_y <- pred_y %>% rownames_to_column('Patient') %>% 
    left_join(anno_val, by = 'Patient')
  
  # Agreement between predicted and observed scores in the validation fold
  pearson <- cor(pred_y$PredScore, pred_y$NMFscore, method = 'pearson')
  spearman <- cor(pred_y$PredScore, pred_y$NMFscore, method = 'spearman')
  
  # Model size = number of non-zero gene coefficients. The intercept row is
  # excluded so the count matches the gene weights extracted from the final models.
  model_coefs <- as.matrix(coef(model, s = lambda))
  gene_coefs <- model_coefs[rownames(model_coefs) != '(Intercept)', 1]
  
  summary_metrics <- data.frame(
    'model_id' = paste0(feature_name, '_iter', iteration, '_', foldname),
    'lambda' = lambda,
    'model_size' = sum(gene_coefs != 0),
    'pearson' = pearson,
    'spearman' = spearman,
    'features' = feature_name,
    'iteration' = iteration,
    'foldname' = foldname
  )
  return(summary_metrics)
}

```


```{r}
gridsearch_lasso <- function(expr_train, expr_val, anno_train, anno_val, features, feature_name,
                             iteration, foldname, summary_metrics){
  # Fit one LASSO on the training fold restricted to `features`, evaluate it on the
  # validation fold at both lambda.min and lambda.1se, and append the two result
  # rows to `summary_metrics`, which is returned.

  # Columns (genes) of each matrix that belong to this feature set
  keep_train <- colnames(expr_train) %in% features
  keep_val <- colnames(expr_val) %in% features
  x_train <- expr_train[, keep_train]
  x_val <- expr_val[, keep_val]

  # Single fit; both lambda choices are read from the same model
  fitted_model <- train_LASSO(x_train, anno_train)

  lambda_choices <- c('lambda.min', 'lambda.1se')
  for(choice in lambda_choices){
    fold_result <- evaluate_model(model = fitted_model, x_val = x_val, anno_val = anno_val,
                                  lambda = choice, feature_name = feature_name,
                                  iteration = iteration, foldname = foldname)
    summary_metrics <- rbind(summary_metrics, fold_result)
  }

  summary_metrics
}

```


```{r}
nestedCV_regression <- function(train_anno, train_expr, iteration, feature_sets, summary_metrics,
                                n_folds = 5){
  # One repeat of outer k-fold cross-validation over every feature set.
  #
  # train_anno      : data frame with columns Patient and NMFscore
  # train_expr      : samples x genes matrix whose rownames contain train_anno$Patient
  # iteration       : repeat index; also used as the random seed
  # feature_sets    : named list of gene vectors
  # summary_metrics : data frame of results so far; new rows are appended
  # n_folds         : number of outer folds (default 5, the original behaviour)
  #
  # Returns summary_metrics with two rows (lambda.min, lambda.1se) added per
  # fold x feature set.

  # Seed and shuffle. Expression rows are later looked up by Patient name, so the
  # second shuffle does not affect sample alignment; it is kept because it draws
  # from the RNG and therefore keeps fold assignment identical to earlier runs.
  set.seed(iteration)
  train_anno <- train_anno[sample(nrow(train_anno)),]
  train_expr <- train_expr[sample(nrow(train_expr)), , drop = FALSE]
  
  ## outer cross validation
  folds <- rsample::vfold_cv(train_anno, v = n_folds)
  for(outer_cv in seq_len(n_folds)){
    # fold ID
    foldname <- folds$id[[outer_cv]]
    # annotation splits for this fold
    anno_train <- rsample::analysis(folds$splits[[outer_cv]])
    anno_val <- rsample::assessment(folds$splits[[outer_cv]])
    # matching expression rows, selected by Patient name
    expr_train <- train_expr[anno_train$Patient, , drop = FALSE]
    expr_val <- train_expr[anno_val$Patient, , drop = FALSE]
    
    # Fit and evaluate a LASSO regression for every feature set on this fold
    for(feature_name in names(feature_sets)){
      features <- feature_sets[[feature_name]]
      summary_metrics <- gridsearch_lasso(expr_train = expr_train, expr_val = expr_val, anno_train = anno_train, anno_val = anno_val, 
                                  features = features, feature_name = feature_name, iteration = iteration, foldname = foldname,
                                  summary_metrics = summary_metrics)
    }
  }
  return(summary_metrics)
}

```


```{r}
library(tidymodels)
library(glmnet)

# Repeated cross-validation driver: for each NMF component, build eight candidate
# feature sets, run 10 seeded repeats of nestedCV_regression, and collect all fold
# results in `output`, which is written to CSV at the end.
output <- data.frame()
# samples x genes matrix, restricted to the patients that have NMF scores
train_x <- pseudobulk_vst@assays$RNA@data[,unique(NMF_ptscores$Patient)] %>% data.matrix() %>% t()

for(NMFcomp in c('NMF1', 'NMF2', 'NMF3', 'NMF4', 'NMF5', 'NMF6', 'NMF7', 'NMF8', 'NMF9', 'NMF10')){
  
  print(paste0('NMF Component: ', NMFcomp))
  
  temp_output <- data.frame()
  
  # Response for this component
  train_y <- NMF_ptscores %>% filter(NMF == NMFcomp) %>% select(Patient, NMFscore)
  # Feature sets: 'PosCorr_*' keep only genes with pearson > 0 for this component,
  # 'AnyCorr_*' keep either sign; the suffix names the DE flag(s) required to be TRUE.
  featurespace <- list('PosCorr_LinDE_FDR05' = NMFcorr %>% filter(NMF == NMFcomp, LinDE_FDR05 == TRUE, pearson > 0) %>% pull(Gene), 
                       'PosCorr_LinDE_FDR01' = NMFcorr %>% filter(NMF == NMFcomp, LinDE_FDR01 == TRUE, pearson > 0) %>% pull(Gene), 
                       'PosCorr_LinDE_BDevDE_FDR05' = NMFcorr %>% filter(NMF == NMFcomp, LinDE_FDR05 == TRUE, BDevDE_FDR05 == TRUE, pearson > 0) %>% pull(Gene), 
                       'PosCorr_LinDE_BDevDE_FDR01' = NMFcorr %>% filter(NMF == NMFcomp, LinDE_FDR01 == TRUE, BDevDE_FDR01 == TRUE, pearson > 0) %>% pull(Gene), 
                       'AnyCorr_LinDE_FDR05' = NMFcorr %>% filter(NMF == NMFcomp, LinDE_FDR05 == TRUE) %>% pull(Gene),
                       'AnyCorr_LinDE_FDR01' = NMFcorr %>% filter(NMF == NMFcomp, LinDE_FDR01 == TRUE) %>% pull(Gene),
                       'AnyCorr_LinDE_BDevDE_FDR05' = NMFcorr %>% filter(NMF == NMFcomp, LinDE_FDR05 == TRUE, BDevDE_FDR05 == TRUE) %>% pull(Gene), 
                       'AnyCorr_LinDE_BDevDE_FDR01' = NMFcorr %>% filter(NMF == NMFcomp, LinDE_FDR01 == TRUE, BDevDE_FDR01 == TRUE) %>% pull(Gene)
  )
  
  # 10 repeats; the iteration number doubles as the seed inside nestedCV_regression
  for(iteration in 1:10){
    print(paste0('iteration ', iteration))
    temp_output <- nestedCV_regression(train_anno = train_y, train_expr = train_x, iteration = iteration, feature_sets = featurespace, 
                                       summary_metrics = temp_output) 
  }
  ## annotate and add to final output
  output <- bind_rows(output, temp_output %>% mutate(NMF = NMFcomp))
}

output %>% write_csv('RepNestedCV_results_NMFregression.csv')
```


**5-Fold cross validation with 10 repeats within the pseudobulk to estimate the best parameters and get a gestalt of the overall accuracy**
After choosing the best combination of parameters we will test on the bulk RNA-seq dataset.


```{r}
output <- read_csv('RepNestedCV_results_NMFregression.csv') 
output %>% pull(features) %>% table()
```

```{r, fig.height = 6, fig.width = 12}
# Model size per feature set (x axis ordered by decreasing model size), coloured by
# lambda choice and faceted by NMF component. Dashed lines mark sizes 10/20/30/40.
output %>%
  mutate(
    NMF = factor(
      NMF,
      levels = c('NMF1', 'NMF2', 'NMF3', 'NMF4', 'NMF5', 'NMF6', 'NMF7', 'NMF8', 'NMF9', 'NMF10')
    )
  ) %>%
  ggplot(aes(x = reorder(features, -model_size), y = model_size, fill = lambda)) +
  geom_hline(yintercept = 10, lty = 2) +
  geom_hline(yintercept = 20, lty = 2) +
  geom_hline(yintercept = 30, lty = 2) +
  geom_hline(yintercept = 40, lty = 2) +
  geom_boxplot(outlier.size = 0.8) +
  ggbeeswarm::geom_quasirandom(dodge.width = 0.7, size = 0.2, alpha = 0.7) +
  facet_wrap(.~NMF, ncol = 6) +
  theme_pubr() +
  theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5)) +
  stat_compare_means(label = 'p.signif')
  
```


```{r, fig.height = 6, fig.width = 12}
# Validation-fold Pearson correlation per feature set (x axis ordered by decreasing
# correlation), coloured by lambda choice and faceted by NMF component.
output %>%
  mutate(
    NMF = factor(
      NMF,
      levels = c('NMF1', 'NMF2', 'NMF3', 'NMF4', 'NMF5', 'NMF6', 'NMF7', 'NMF8', 'NMF9', 'NMF10')
    )
  ) %>%
  ggplot(aes(x = reorder(features, -pearson), y = pearson, fill = lambda)) +
  geom_boxplot(outlier.size = 0.8) +
  ggbeeswarm::geom_quasirandom(dodge.width = 0.7, size = 0.2, alpha = 0.7) +
  facet_wrap(.~NMF, ncol = 6) +
  theme_pubr() +
  theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5)) +
  stat_compare_means(label = 'p.signif')
  
```


```{r}
NMFconvert
```


In summary, using combination of positively and negatively correlated genes leads to better performance, particularly for: 
  NMF1 (Pro-B)
  NMF5 (Erythroid)
  NMF11 (Naive T)
  
For filtering the correlated genes, requiring both lineage DE and B-development DE at FDR < 0.01 consistently led to the best performance. 

Choosing lambda.1se (the minimum-error lambda + 1 SE) rather than lambda.min reduced model size by nearly 20 genes without sacrificing nested-CV performance. 

```{r}
# Median model size and validation correlations per NMF component for the chosen
# configuration (lambda.1se, AnyCorr_LinDE_BDevDE_FDR01), sorted by median pearson.
# Alternative feature set evaluated: AnyCorr_LinDE_FDR01.
output_CVmedians <- output %>%
  filter(lambda == 'lambda.1se', features == 'AnyCorr_LinDE_BDevDE_FDR01') %>%
  group_by(NMF) %>%
  summarise(across(c(model_size, pearson, spearman), median)) %>%
  arrange(pearson)

output_CVmedians 
```



### Final Choice: Corr Threshold + Lineage DE FDR01; lambda min + 1SE

Train from both positively corr and any corr. If models from both approaches have negative coefficients, use any corr. 
Also make sure that the genes are present in the bulk RNAseq data


```{r}
train_LASSO <- function(x_train, y_train, alpha = 1){
  # Fit a Gaussian elastic-net regression of NMF score on expression, choosing
  # lambda by leave-one-out cross-validation.
  #
  # x_train : samples x genes numeric matrix
  # y_train : data frame with an NMFscore column, one row per row of x_train and
  #           in the same order (alignment is the caller's responsibility)
  # alpha   : elastic-net mixing parameter passed to glmnet (1 = LASSO)
  #
  # Returns the cv.glmnet fit.
  
  train_y <- y_train$NMFscore
  
  # nfolds = number of samples gives leave-one-out CV. The argument is spelled out
  # in full (it was previously reached only via partial matching of `nfold`), and
  # grouped = FALSE is set explicitly because cv.glmnet enforces it, with a
  # warning, whenever there are fewer than 3 observations per fold.
  # standardize = FALSE: predictors are used on their input scale.
  model <- cv.glmnet(x = x_train, y = train_y, nfolds = nrow(x_train), grouped = FALSE,
                     family = 'gaussian', alpha = alpha, maxit = 1000000, standardize = FALSE)
  #plot(model)

  return(model)
}
```

```{r}
bulkRNAgenes <- data.table::fread("../BALLbulk_Deconvolution/BALL_bulkRNA_data/BALL_BulkRNAseq_subsetgene_rawcounts.txt")$Gene
bulkRNAgenes %>% length()
```


```{r}
# Train the final model for each NMF component on all pseudobulk samples.
# Models are stored in model_list under the component name.
set.seed(123)
model_list <- list()
# Restrict candidate genes to those present in the bulk RNA-seq gene list, so every
# selected gene can be scored in the bulk data. This overwrites NMFcorr in place.
NMFcorr <- NMFcorr %>% filter(Gene %in% bulkRNAgenes)

for(NMFcomp in c('NMF1', 'NMF2', 'NMF3', 'NMF4', 'NMF5', 'NMF6', 'NMF7', 'NMF8', 'NMF9', 'NMF10')){
  
  # Response: NMF score per patient for this component
  train_y <- NMF_ptscores %>% filter(NMF == NMFcomp) %>% select(Patient, NMFscore)
  # Feature space: correlated genes flagged in both DE sets at FDR < 0.01, either correlation sign
  feature_space <- NMFcorr %>% filter(NMF == NMFcomp, LinDE_FDR01 == TRUE, BDevDE_FDR01 == TRUE) %>% pull(Gene) #BDevDE_FDR01 == TRUE
  # samples x genes matrix; columns are taken in train_y$Patient order so rows align with train_y
  train_x <- pseudobulk_vst@assays$RNA@data[feature_space, train_y$Patient] %>% data.matrix() %>% t()
  
  model <- train_LASSO(train_x, y_train = train_y)
  model_list[[NMFcomp]] <- model
}
```

```{r}
# Collect the non-zero coefficients of every final model at lambda.1se into one
# long table with columns Model, Gene, Weight.
modelweights <- data.frame()

for(NMFcomp in c('NMF1', 'NMF2', 'NMF3', 'NMF4', 'NMF5', 'NMF6', 'NMF7', 'NMF8', 'NMF9', 'NMF10')){
  # coef() -> dense one-column table -> rename the column (expected to be called
  # 's1') to Weight -> tail(-1) drops the first row (the intercept row of a glmnet
  # coefficient matrix) -> keep non-zero weights, largest first.
  modelweights <- modelweights %>% bind_rows(
    model_list[[NMFcomp]] %>% coef(s = 'lambda.1se') %>% data.matrix() %>% 
      data.frame() %>% dplyr::rename(Weight = s1) %>% rownames_to_column('Gene') %>% 
      tail(-1) %>% filter(Weight != 0) %>% arrange(-Weight) %>% mutate(Model = NMFcomp)
  )
}

modelweights <- modelweights %>% select(Model, Gene, Weight)
# Number of selected genes per model
modelweights %>% group_by(Model) %>% summarise(count = n())
```


```{r}
# NOTE(review): this chunk repeats the preceding chunk verbatim; it rebuilds
# `modelweights` from scratch, so running both yields the same table.
# Long table (Model, Gene, Weight) of non-zero lambda.1se coefficients per model.
modelweights <- data.frame()

for(NMFcomp in c('NMF1', 'NMF2', 'NMF3', 'NMF4', 'NMF5', 'NMF6', 'NMF7', 'NMF8', 'NMF9', 'NMF10')){
  # tail(-1) drops the first coefficient row (the intercept row in glmnet output)
  modelweights <- modelweights %>% bind_rows(
    model_list[[NMFcomp]] %>% coef(s = 'lambda.1se') %>% data.matrix() %>% 
      data.frame() %>% dplyr::rename(Weight = s1) %>% rownames_to_column('Gene') %>% 
      tail(-1) %>% filter(Weight != 0) %>% arrange(-Weight) %>% mutate(Model = NMFcomp)
  )
}

modelweights <- modelweights %>% select(Model, Gene, Weight)
# Number of selected genes per model
modelweights %>% group_by(Model) %>% summarise(count = n())
```


```{r}
modelweights %>% 
  left_join(NMFconvert %>% dplyr::rename(Model = NMF)) %>%
  mutate(coefficient = ifelse(Weight > 0, 'Positive', 'Negative') %>% factor(levels = c('Positive', 'Negative'))) %>% 
  group_by(NMFnamed, coefficient) %>% summarise(count = n()) %>%
  ggplot(aes(x = NMFnamed, y = count, fill = coefficient)) + geom_col() + ggpubr::theme_pubr() + 
  ggsci::scale_fill_simpsons() + theme(axis.text.x = element_text(angle = 90, hjust = 1))
```





```{r}
# Display order of the lineage names; also used as factor levels when plotting.
NMFnamed_levels <- c('HSC_MPP', 'Myeloid_Prog', 'Pre_pDC', 'Early_Lymphoid', 'Pro_B', 'Pre_B', 
                      'Mature_B', 'Erythroid', 'Monocyte', 'T_NK')

# Lookup table from NMF component ID to lineage name. The two vectors are paired
# by position (NMF8 -> HSC_MPP, NMF6 -> Myeloid_Prog, ..., NMF10 -> T_NK).
# NOTE(review): earlier chunks in this notebook already reference NMFconvert, so
# this chunk has to be run before them in a fresh session.
NMFconvert <- data.frame(
  'NMF' = c('NMF8', 'NMF6', 'NMF7', 'NMF2', 'NMF1', 'NMF4', 
            'NMF5', 'NMF3', 'NMF9', 'NMF10') %>% factor(),
  'NMFnamed' = NMFnamed_levels %>% factor(levels = NMFnamed_levels)
)

NMFconvert
```

```{r}
modelweights <- modelweights %>% left_join(NMFconvert %>% dplyr::rename(Model = NMF)) %>% 
  arrange(Gene) %>% arrange(NMFnamed) %>% pivot_wider(id_cols=Gene, names_from=NMFnamed, values_from=Weight) %>% replace(is.na(.), 0) 

modelweights %>% write_csv("NMF_Lasso_ModelWeights.csv")
modelweights
```

## Figure out NMF scoring and validate on pseudobulk

```{r}
calculate_NMFscores <- function(query, modelweights, scale = TRUE, sampleID = 'Patient'){
  # Score every sample of an expression matrix with the lineage model weights.
  #
  # query        : genes x samples numeric matrix with gene names as rownames
  # modelweights : data frame with a `Gene` column and one numeric weight column
  #                per signature (0 where a gene is not part of a signature)
  # scale        : if TRUE, centre and scale each signature's scores across samples
  # sampleID     : name of the output column that holds the sample names
  #
  # Returns a data frame with one row per sample: the sample-name column followed
  # by one score column per signature. Emits a warning when some model genes are
  # absent from `query`; those genes are left out of the score.
  
  querygenes <- rownames(query)
  modelweights <- as.data.frame(modelweights)
  
  # Report model genes that cannot be scored in this dataset
  gene_present <- modelweights$Gene %in% querygenes
  modelweights_missing <- sum(!gene_present)
  if(modelweights_missing > 0){
    warning(modelweights_missing, ' genes from NMF models are missing from query dataset', call. = FALSE)
  }
  
  # Keep the scorable genes and build the genes x signatures weight matrix
  modelweights <- modelweights[gene_present, , drop = FALSE]
  model_genes <- as.character(modelweights$Gene)
  weight_cols <- setdiff(colnames(modelweights), 'Gene')
  modelweights_mat <- data.matrix(modelweights[, weight_cols, drop = FALSE])
  rownames(modelweights_mat) <- model_genes
  
  # samples x genes  %*%  genes x signatures. drop = FALSE keeps the subset a
  # matrix even when only a single model gene is present.
  scored <- as.matrix(t(query[model_genes, , drop = FALSE]) %*% modelweights_mat)
  if(isTRUE(scale)){
    # base:: makes explicit that the function is meant, not the `scale` argument
    scored <- base::scale(scored)
  }
  
  # Move the sample names from rownames into a leading column named by sampleID
  scored <- as.data.frame(scored)
  sample_col <- stats::setNames(data.frame(rownames(scored), stringsAsFactors = FALSE), sampleID)
  scored <- cbind(sample_col, scored)
  rownames(scored) <- NULL
  
  return(scored)
}
```


```{r}
modelweights <- read_csv("NMF_Lasso_ModelWeights.csv")
modelweights
```

```{r}
pseudobulk_vst <- readRDS("../BALLbulk_Deconvolution/BALL_89pt_pseudobulk_vst.rds")
# calculate NMF scores from vst-normalized data
pseudobulk_vst_NMFscores <- calculate_NMFscores(pseudobulk_vst@assays$RNA@data, modelweights, scale = T, sampleID = 'Patient')
pseudobulk_vst_NMFscores[1:10,1:10]
```

**Training results - pseudobulk**

```{r, fig.height = 5, fig.width = 12}
NMF_compare <- NMF_ptscores %>% 
  left_join(NMFconvert) %>% 
  left_join(pseudobulk_vst_NMFscores %>% pivot_longer(-Patient, names_to = 'NMFnamed', values_to = 'predNMF')) 

NMF_compare %>% 
  mutate(NMFnamed = factor(NMFnamed, levels = NMFnamed_levels)) %>% 
  ggplot(aes(x = predNMF, y = NMFscore)) + 
  geom_point() + geom_smooth(method = 'lm') + 
  facet_wrap(.~NMFnamed, scales = 'free') + 
  theme_pubr() + stat_cor()
```




### True Bulk RNAseq "validation"

```{r}
bulkRNA_counts <- data.table::fread("../BALLbulk_Deconvolution/BALL_bulkRNA_data/BALL_BulkRNAseq_subsetgene_rawcounts.txt") %>% 
  column_to_rownames('Gene') %>% data.matrix()
bulkRNA_counts[1:10,1:10]
```

```{r}
library(DESeq2)
bulkRNA_dds <- DESeqDataSetFromMatrix(bulkRNA_counts[rowSums(bulkRNA_counts) >= 10,], 
                       colData = data.frame('Patient' = colnames(bulkRNA_counts)) %>% column_to_rownames('Patient'), 
                       design = ~1)
bulkRNA_vst <- assay(vst(bulkRNA_dds))
rm(bulkRNA_counts, bulkRNA_dds)

bulkRNA_vst[1:10,1:10]
```

```{r}
# calculate NMF scores from vst-normalized data
bulkRNA_vst_NMFscores <- calculate_NMFscores(bulkRNA_vst, modelweights, scale = T, sampleID = 'Patient')
bulkRNA_vst_NMFscores[1:10,1:10]
```

**Validation results - matched bulkRNAseq**

```{r}
NMF_ptscores_converted <- NMF_ptscores %>% 
  left_join(read_delim("../BALL_metadata_20230105.txt", delim = '\t') %>% select(Patient = Directory, Sample, ID, TB)) 

NMF_ptscores_compareBulkRNA <- NMF_ptscores_converted %>% 
  left_join(NMFconvert) %>% 
  left_join(bulkRNA_vst_NMFscores %>% pivot_longer(-Patient, names_to = 'NMFnamed', values_to = 'predNMF') %>% 
              dplyr::rename(ID = Patient)) 

NMF_ptscores_compareBulkRNA
```


```{r, fig.height = 3.5, fig.width = 12}
NMF_ptscores_compareBulkRNA %>% 
  mutate(NMFnamed = factor(NMFnamed, NMFnamed_levels)) %>% 
  ggplot(aes(x = predNMF, y = NMFscore)) + 
  geom_point() + geom_smooth(method = 'lm') + 
  facet_wrap(.~NMFnamed, scales = 'free', ncol = 5) + 
  theme_pubr() + stat_cor() + 
  ylab('NMF Lineage Score (scRNA composition)') + xlab('Predicted NMF Score (matched bulk RNA-seq)')
```



```{r, fig.height = 6, fig.width = 4}
NMF_ptscores_compareBulkRNA %>% 
  mutate(NMFnamed = factor(NMFnamed, NMFnamed_levels)) %>% 
  ggplot(aes(x = predNMF, y = NMFscore)) + 
  geom_point() + geom_smooth(method = 'lm') + 
  facet_wrap(.~NMFnamed, scales = 'free', ncol = 2) + 
  theme_pubr() + stat_cor() + 
  ylab('NMF Lineage Score (scRNA composition)') + xlab('Predicted NMF Score (matched bulk RNA-seq)')
```

Not bad!!
Save bulk RNAseq NMF scores: 

```{r}
bulkRNA_vst_NMFscores %>% write_csv('Bulk2046_NMFregression_LineageScores.csv')
bulkRNA_vst_NMFscores
```


```{r}
output %>% 
  left_join(NMFconvert) %>%
  #mutate(NMF = factor(NMF, levels = c('NMF1', 'NMF2', 'NMF3', 'NMF4', 'NMF5', 'NMF6', 'NMF7', 'NMF8', 'NMF9', 'NMF10'))) %>% 
  filter(features == 'AnyCorr_LinDE_BDevDE_FDR01', lambda == 'lambda.1se') %>% 
  ggplot(aes(x = NMFnamed, y = pearson, fill = NMFnamed)) + 
  theme_pubr(legend = 'none') + geom_hline(yintercept = 0.5, lty = 2) + geom_hline(yintercept = 0.75, lty = 2) + geom_hline(yintercept = 0.9, lty = 2) + 
  geom_boxplot(outlier.size = 0) + ggbeeswarm::geom_quasirandom(size = 1, alpha = 0.7) + 
  theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5)) 
  
```


```{r}
# Same per-signature `pearson` boxplot, with signatures ordered on the x axis
# by decreasing mean `pearson` (reorder() summarises with mean by default).
# Dashed guide lines are drawn at 0.5, 0.75 and 0.9.
output %>% 
  left_join(NMFconvert) %>%
  #mutate(NMF = factor(NMF, levels = c('NMF1', 'NMF2', 'NMF3', 'NMF4', 'NMF5', 'NMF6', 'NMF7', 'NMF8', 'NMF9', 'NMF10'))) %>% 
  filter(features == 'AnyCorr_LinDE_BDevDE_FDR01', lambda == 'lambda.1se') %>% 
  ggplot(aes(x = reorder(NMFnamed, -pearson), y = pearson, fill = NMFnamed)) + 
  theme_pubr(legend = 'none') + 
  geom_hline(yintercept = c(0.5, 0.75, 0.9), lty = 2) + 
  # Every observation is already drawn by the beeswarm layer, so suppress the
  # boxplot's own outlier glyphs. outlier.shape = NA hides them entirely,
  # whereas outlier.size = 0 can still render a visible dot.
  geom_boxplot(outlier.shape = NA) + 
  ggbeeswarm::geom_quasirandom(size = 1, alpha = 0.7) + 
  theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5)) 
  
  
```

```{r}
# Rows of the comparison table with no missing values.
# NOTE(review): this chunk looks unfinished. No grouping columns or summary
# were ever specified, so only the complete-case table is displayed.
NMF_ptscores_compareBulkRNA %>% drop_na()
```






## Quickly evaluate on sorted cord blood fractions

```{r}
# Read the cord blood table, move the 'Gene' column into the row names and
# convert to a numeric matrix; then report its dimensions.
cb_fractions <- data.table::fread('../../../../../../../Dormancy/HSC_analysis/HSC_byOntogeny/Dicklab_sortedFractions_Batch4_CB_Hierarchy_vst.csv') %>%
  column_to_rownames('Gene') %>%
  data.matrix()
dim(cb_fractions)
```


```{r}
# Score the cord blood matrix with the fitted model weights.
# TRUE is spelled out because `T` is an ordinary variable that can be reassigned.
CBfractions_vst_NMFscores <- calculate_NMFscores(cb_fractions, modelweights, scale = TRUE, sampleID = 'Sample')
CBfractions_vst_NMFscores
```

```{r, fig.height = 8, fig.width = 10}
# Score per sorted population, one free-scaled panel per NMF signature.
# Population is whatever follows the last 'CB_' in the sample name; names that
# are not among the listed levels become NA when converted to a factor.
CBfractions_vst_NMFscores %>% 
  pivot_longer(-Sample, names_to = 'NMFsig', values_to = 'Score') %>% 
  mutate(NMFsig = factor(NMFsig, levels = NMFnamed_levels), 
         Population = Sample %>% str_replace('.*CB_',''),
         Population = factor(Population, levels = c('HSC', 'MPP', 'LMPP', 'CMP', 'GMP', 'MLPII', 'EarlyProB', 'PreProB', 
                                                    'ProB', 'PreB', 'B', #'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 
                                                    'T', 'NK', 'EryP', 'Mono', 'Gr'))) %>% 
  # Drop samples from unlisted populations. Test for NA explicitly rather than
  # comparing against the string 'NA', which only works because filter()
  # discards rows where the comparison itself evaluates to NA.
  filter(!is.na(Population)) %>% 
  ggplot(aes(x = Population, y = Score, fill = Population)) + 
  geom_boxplot() + geom_jitter() +
  theme_pubr(legend = 'none') + theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5)) +
  facet_wrap(.~NMFsig, scales = 'free', ncol = 3)
  
  
```

```{r}
# Long-format cord blood scores annotated with the sorted population.
# Population is whatever follows the last 'CB_' in the sample name; names that
# are not among the listed levels become NA.
# The pipeline previously ended in a dangling `%>%`, which is a parse error;
# it now ends at mutate() and displays the table.
CBfractions_vst_NMFscores %>% 
  pivot_longer(-Sample, names_to = 'NMFsig', values_to = 'Score') %>% 
  mutate(NMFsig = factor(NMFsig, levels = NMFnamed_levels), 
         Population = Sample %>% str_replace('.*CB_',''),
         Population = factor(Population, levels = c('HSC', 'MPP', 'LMPP', 'CMP', 'GMP', 'MLPII', 'EarlyProB', 'PreProB', 
                                                    'ProB', 'PreB', 'B', #'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 
                                                    'T', 'NK', 'EryP', 'Mono', 'Gr')))
```



## Evaluate on Pharmacotypes 

```{r}
# Convert an RPKM/FPKM table to log-transformed TPM.
#
# dat: data frame with the gene symbol column named "Gene" (required) and one
#      numeric RPKM/FPKM column per sample.
#
# Returns `dat` with every sample column replaced by
# log1p(value / sum(value) * 1e6), i.e. the natural log of (TPM + 1).
# Each column is transformed in place rather than reshaping the table to long
# format and back. That avoids building a genes x samples long table, keeps
# the input's row and column order, and does not fail on repeated gene symbols.
# A sample containing NA yields NA throughout, because sum() is not NA-removing.
rpkm_to_logTPM <- function(dat){
  sample_cols <- setdiff(colnames(dat), 'Gene')
  for (sample_col in sample_cols) {
    rpkm <- dat[[sample_col]]
    # Rescale so the sample sums to one million (TPM), then log1p.
    dat[[sample_col]] <- log1p(rpkm / sum(rpkm) * 1000000)
  }
  dat
}
```


```{r}
# Read the FPKM table, drop the GeneID column and rename GeneName to Gene
# (the column name rpkm_to_logTPM() requires).
pharmacotype_fpkm <- data.table::fread('pharmacotypes/pharmacotyping_ped_rnaseq_fpkm_ALLids_0823.csv') %>%
  select(-GeneID) %>%
  dplyr::rename(Gene = GeneName)
pharmacotype_fpkm
```

```{r}
# FPKM -> log TPM, then a numeric matrix with gene symbols as row names.
pharmacotype_logTPM <- pharmacotype_fpkm %>%
  rpkm_to_logTPM() %>%
  column_to_rownames('Gene') %>%
  data.matrix()
dim(pharmacotype_logTPM)
```


```{r}
# Score the pharmacotype expression matrix with the fitted model weights,
# write the scores to disk and display them.
# TRUE is spelled out because `T` is an ordinary variable that can be reassigned.
pharmacotype_logTPM_scored <- calculate_NMFscores(pharmacotype_logTPM, modelweights, scale = TRUE, sampleID = 'Patient ID')
pharmacotype_logTPM_scored %>% write_csv('pharmacotypes/ALL_pharmacotypes_logTPM_DevState_scores.csv')
pharmacotype_logTPM_scored
```


```{r}
# Load the pharmacotype table (ALL_invitro_pharmacotypes.csv) and display it.
pharmacotypes <- read_csv('pharmacotypes/ALL_invitro_pharmacotypes.csv')
pharmacotypes
```


```{r}
# Attach the lineage scores to the pharmacotype table. The inner join uses the
# columns the two tables share, so rows without a match on either side are
# dropped. Then keep only rows whose Immunophenotype is 'B'.
pharmacotypes_combined <- inner_join(pharmacotypes, pharmacotype_logTPM_scored) %>%
  filter(Immunophenotype == 'B')
pharmacotypes_combined
```

```{r}
# Spearman correlation between every column whose name contains 'normalized'
# (rows of the result) and each lineage score (columns of the result),
# computed on pairwise-complete observations.
CellType_Drug_Corr <- cor(
  x = pharmacotypes_combined %>% select(contains('normalized')),
  y = pharmacotypes_combined %>%
    select(HSC_MPP, Myeloid_Prog, Pre_pDC, Early_Lymphoid,
           Pro_B, Pre_B, Mature_B, T_NK, Monocyte, Erythroid),
  method = 'spearman',
  use = 'pairwise.complete.obs'
)
CellType_Drug_Corr
```

```{r}
# Display the correlation matrix.
CellType_Drug_Corr
```

```{r}
# Draw the correlation matrix, transposed so lineage scores are the rows.
library(corrplot)
CellType_Drug_Corr %>% t() %>% corrplot()
```


## Jae Kim B-ALL Ph+

```{r}
# File names found in the raw-counts directory.
# NOTE(review): every file in the directory is listed; presumably all of them
# are per-sample '*_count_sub.txt' count files. Confirm.
Kim_PhALL_samples <- list.files('subtype_subcluster/Kim2023_Ph_BALL/RNAseq_rawcounts/')
Kim_PhALL_samples
```

```{r}
# Load the clinical annotation table and display it.
Kim_PhALL_anno <- read_csv('subtype_subcluster/Kim2023_Ph_BALL/Ph_BALL_ClinicalAnno.csv')
Kim_PhALL_anno
```

```{r}
# Read one per-sample count file and label its two columns as 'ENSG' and the
# sample name (the file name without '_count_sub.txt').
# NOTE(review): assumes each file has exactly two columns (gene id, count).
read_kim_phALL_sample <- function(ph_samp){
  print(ph_samp)
  sample_counts <- data.table::fread(paste0('subtype_subcluster/Kim2023_Ph_BALL/RNAseq_rawcounts/', ph_samp))
  colnames(sample_counts) <- c('ENSG', ph_samp %>% str_replace('_count_sub.txt',''))
  sample_counts
}

# Build one table with ENSG plus one count column per sample. The first file
# defines the set of genes; each later file is left-joined onto it in order.
# The join key is stated explicitly so a stray shared column name can never
# become part of the key, and the "Joining with" messages are not emitted.
kim_phALL_counts <- Reduce(
  function(merged, sample_counts) left_join(merged, sample_counts, by = 'ENSG'),
  lapply(Kim_PhALL_samples, read_kim_phALL_sample)
)

kim_phALL_counts
```

```{r}
# Lookup table of distinct (ENSG, Gene) pairs taken from the first two columns
# (V1, V2) of the transcript-length file.
ENSGconvert <- data.table::fread('../../../../../../../CIBERSORT/newDataSets_Jul2020/preprocessing/GRCh38_transcript_lengths.txt') %>%
  select(ENSG = V1, Gene = V2) %>%
  unique()
ENSGconvert
```

```{r}
# Map ENSG ids to gene symbols and collapse to one row per symbol by summing
# the counts of all ids that share it. Ids without an entry in ENSGconvert are
# dropped by the inner join. The join key is explicit, and across() replaces
# the superseded summarise_all().
kim_phALL_counts <- kim_phALL_counts %>% 
  inner_join(ENSGconvert, by = 'ENSG') %>% 
  select(-ENSG) %>% 
  select(Gene, everything()) %>% 
  group_by(Gene) %>% 
  summarise(across(everything(), sum))
kim_phALL_counts
```

```{r}
# Save the gene-level count table.
kim_phALL_counts %>% write_csv('subtype_subcluster/Kim2023_Ph_BALL/Kim2023_Ph_BALL_RNAseq_counts.csv')
```


```{r}
library(DESeq2)

# Variance-stabilising transformation of the gene-level counts.
# countData: genes x samples matrix with gene symbols as row names.
# colData:   a data frame with no columns whose row names are the sample
#            names, which is all the intercept-only design (~1) requires.
kim_phALL_vst <- DESeqDataSetFromMatrix(
  countData = kim_phALL_counts %>% column_to_rownames('Gene') %>% data.matrix(),
  colData = data.frame('Patient' = colnames(kim_phALL_counts)[-1]) %>% column_to_rownames('Patient'),
  design = ~1
) %>%
  vst() %>%
  assay()
# The count table is no longer needed once the transformed matrix exists.
rm(kim_phALL_counts)

kim_phALL_vst[1:10,1:10]
```

```{r}
# Save the transformed matrix, with the row names written out as a 'Gene' column.
kim_phALL_vst %>% as.data.frame() %>% rownames_to_column('Gene') %>% write_csv('subtype_subcluster/Kim2023_Ph_BALL/Kim2023_Ph_BALL_RNAseq_vst.csv')
```


```{r}
# Calculate NMF scores from the vst-normalized data and preview the result.
# TRUE is spelled out because `T` is an ordinary variable that can be reassigned.
kim_phALL_vst_NMFscores <- calculate_NMFscores(kim_phALL_vst, modelweights, scale = TRUE, sampleID = 'Sample')
kim_phALL_vst_NMFscores[1:10,1:10]
```

```{r}
# Save the NMF scores for this cohort.
kim_phALL_vst_NMFscores %>% write_csv('subtype_subcluster/Kim2023_Ph_BALL/Kim2023_Ph_BALL_LineageNMF_Scored.csv')
```


# Compare lineage scores with clinical annotation

```{r}
# Add a Sample column to the annotation and save it. Sample is the JAMLR id
# plus '_nn_M' when Manuscript_name contains '-R', and '_nn_P' otherwise.
Kim_PhALL_anno %>%
  mutate(
    Sample = ifelse(str_detect(Manuscript_name, '-R'),
                    paste0(JAMLR, '_nn_M'),
                    paste0(JAMLR, '_nn_P'))
  ) %>%
  write_csv('subtype_subcluster/Kim2023_Ph_BALL/Kim2023_Ph_BALL_anno_cleaned.csv')
```


```{r}
# Clinical annotation joined to the NMF scores. Sample is the JAMLR id plus
# '_nn_M' when Manuscript_name contains '-R', and '_nn_P' otherwise. The inner
# join uses the shared columns and keeps only samples present in both tables.
Kim_PhALL_anno_LineageScored <- Kim_PhALL_anno %>%
  mutate(
    Sample = ifelse(str_detect(Manuscript_name, '-R'),
                    paste0(JAMLR, '_nn_M'),
                    paste0(JAMLR, '_nn_P'))
  ) %>%
  select(Sample, WBC, Age_at_dx, subtype, Subgroup) %>%
  inner_join(kim_phALL_vst_NMFscores)

Kim_PhALL_anno_LineageScored
```

```{r}
# Summary statistics of the Age_at_dx column.
Kim_PhALL_anno %>% pull(Age_at_dx) %>% summary()
```


```{r, fig.height = 5, fig.width = 12}
# Every remaining score column by Subgroup: one free-scaled panel per column,
# five per row, with a group-comparison p-value from stat_compare_means().
Kim_PhALL_anno_LineageScored %>%
  select(-c(Sample, WBC, Age_at_dx, subtype)) %>%
  pivot_longer(-Subgroup, names_to = 'name', values_to = 'value') %>%
  ggplot(aes(x = Subgroup, y = value, fill = Subgroup)) +
  geom_boxplot() +
  ggbeeswarm::geom_quasirandom() +
  theme_pubr(legend = 'none') +
  theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5)) +
  facet_wrap(~ name, scales = 'free', ncol = 5) +
  stat_compare_means()
```





